用于生成演示、培训和评估的脚本都在 Scripts/ 文件夹中
DP3 通过 gen_demonstration 生成演示,即训练数据,例如:
bash scripts/gen_demonstration_adroit.sh hammer
这将在 Adroit 环境中生成锤子任务的演示。数据将自动保存在 3D-Diffusion-Policy/data/ 文件夹中
接下来将详细解释一下流程逻辑,以及分享一小部分在复现过程中写了一些测试脚本
目录
1 脚本解析
1.1 切换目录
1.2 定义任务变量
1.3 运行 Python 脚本
2 运行效果
3 可视化
3.1 数据结构
3.2 数据可视化
3.3 动态可视化
1 脚本解析
1.1 切换目录
cd third_party/VRL3/src
将工作目录切换到 third_party/VRL3/src
1.2 定义任务变量
task=${1}
从脚本运行时传入的第一个参数来获取任务名称,例如:
bash scripts/gen_demonstration_adroit.sh hammer,task
将被赋值为 "
hammer"
1.3 运行 Python 脚本
CUDA_VISIBLE_DEVICES=0 python gen_demonstration_expert.py --env_name $task \--num_episodes 10 \--root_dir "../../../3D-Diffusion-Policy/data/" \--expert_ckpt_path "../ckpts/vrl3_${task}.pt" \--img_size 84 \--not_use_multi_view \--use_point_crop
通过 Python 脚本 gen_demonstration_expert.py 生成演示数据,参数说明如下:
--env_name $task:指定任务名称(如 "door"、"hammer" 或 "pen")
--num_episodes 10:生成 10 个演示数据集
--root_dir "../../../3D-Diffusion-Policy/data/":设置数据存储的根目录
--expert_ckpt_path "../ckpts/vrl3_${task}.pt":指定专家模型的检查点路径,如下:
--img_size 84:设置图像尺寸为 84x84 像素
--not_use_multi_view:禁用多视角
--use_point_crop:启用点裁剪
2 运行效果
可以看到存储位置:Saved zarr file to ../../../3D-Diffusion-Policy/data/adroit_hammer_expert.zarr
3 可视化
在存储位置找到生成的训练数据,然后进行分析
3.1 数据结构
我们先来看一下生成数据的数据结构,执行代码:
import zarrpointcloud_path = "/home/yejiangchen/Desktop/Codes/3D-Diffusion-Policy-master/3D-Diffusion-Policy/data/adroit_hammer_expert.zarr"# 打开 zarr 数据集
pointcloud_dataset = zarr.open(pointcloud_path, mode='r')# 打印数据集的树状结构来查看所有路径
print(pointcloud_dataset.tree())
运行结果:
此处展开说明一下:
data 目录包含:动作数据 action、深度图数据 depth、图像数据 img、点云数据 point_cloud、状态数据 state
meta 目录包含:回合结束索引 episode_ends
1. action (1000, 26) float32:每个时间步机器人(智能体)采取的动作
1000:表示有 1000 个时间步(或帧)
26:表示每个时间步有 26 个动作特征值
2. depth (1000, 84, 84) float32:每帧对应的深度图信息,每个像素的值表示从传感器到物体表面的距离(深度)
1000:表示有 1000 帧深度图
84 x 84:深度图的分辨率为 84x84 像素
3. img (1000, 84, 84, 3) uint8:每帧的彩色图像,84x84的矩阵,每个元素是长度为3的向量
1000:表示有 1000 帧 RGB 图像
84 x 84:图像分辨率为 84x84 像素
3:表示 3 个颜色通道(Red, Green, Blue)
4. point_cloud (1000, 512, 6) float32:每帧的 3D 点云
1000:表示有 1000 帧点云数据
512:每帧有 512 个点
6:每个点有 6 个特征(通常是 x, y, z 坐标和 r, g, b 颜色值)
5. state (1000, 24) float32:每个时间步机器人(智能体)的状态
1000:表示有 1000 个时间步
24:表示每个时间步的状态特征维度为 24
6. meta/episode_ends (10,) int64:指示 10 个 episode(回合)的结束位置
10:表示有 10 个结束索引,即 1000 帧中,第 0 到第 99 帧属于第一个 episode,第 100 帧是第一个 episode 的结束,共 10 个 episode
3.2 数据可视化
既然经知道数据结构了,接下来就进行可视化,执行代码:
from flask import Flask, render_template_string, request
import zarr
import numpy as np
import plotly.graph_objs as go
import plotly.io as pio
import base64
import io
from PIL import Imageapp = Flask(__name__)# 加载 Zarr 数据集
data_path = "/home/yejiangchen/Desktop/Codes/3D-Diffusion-Policy-master/3D-Diffusion-Policy/data/adroit_hammer_expert.zarr"
dataset = zarr.open(data_path, mode='r')# 提取数据
pointclouds = dataset['data/point_cloud']
images = dataset['data/img']def generate_pointcloud_trace(frame_index):"""生成 Plotly 点云图的 Trace"""pointcloud = pointclouds[frame_index]x, y, z = pointcloud[:, 0], pointcloud[:, 1], pointcloud[:, 2]colors = ['rgb({},{},{})'.format(int(r), int(g), int(b)) for r, g, b in pointcloud[:, 3:6]]trace = go.Scatter3d(x=x,y=y,z=z,mode='markers',marker=dict(size=2,opacity=0.7,color=colors))return tracedef generate_image_data(frame_index):"""将图像数据转换为 base64 格式以便在 HTML 中显示"""img_data = images[frame_index]img = Image.fromarray(img_data)buffered = io.BytesIO()img.save(buffered, format="PNG")img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')return f"data:image/png;base64,{img_str}"@app.route('/')
def index():frame_index = int(request.args.get('frame', 0))max_frame_index = pointclouds.shape[0] - 1# 确保索引在有效范围内frame_index = max(0, min(frame_index, max_frame_index))# 生成点云和图像trace = generate_pointcloud_trace(frame_index)image_data = generate_image_data(frame_index)# 生成 Plotly 图表的 HTMLlayout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))fig = go.Figure(data=[trace], layout=layout)pointcloud_html = pio.to_html(fig, full_html=False)return render_template_string("""<!DOCTYPE html><html><head><title>Dynamic Point Cloud Viewer</title></head><body><h1>Dynamic Point Cloud and Image Viewer</h1><form method="get"><label for="frame">Frame Index (0-{{ max_frame }}):</label><input type="number" id="frame" name="frame" min="0" max="{{ max_frame }}" value="{{ frame }}"><input type="submit" value="Load Frame"></form><h2>Point Cloud</h2><div>{{ pointcloud_html | safe }}</div><h2>Image</h2><img src="{{ image_data }}" alt="Frame Image" style="width:400px;"></body></html>""", pointcloud_html=pointcloud_html, image_data=image_data, frame=frame_index, max_frame=max_frame_index)if __name__ == '__main__':app.run(debug=True, use_reloader=False, port=5000)
运行结果:
此处展开说明一下:
此代码使用 Flask 和 Plotly 来实现一个网页界面,可以动态切换和查看不同帧的点云数据和图像数据,它支持:
动态查看点云数据:通过滑块来选择不同的帧进行点云可视化
动态查看图像数据:同时显示对应帧的图像数据
Flask 服务器:提供网页界面来交互和展示可视化结果
简单解析:
1. 数据加载:加载 Zarr 数据集路径下的 point_cloud 和 img 数据
2. generate_pointcloud_trace:生成 Plotly 的 3D 散点图(点云可视化)
3. generate_image_data:将图像数据转换为 base64 格式,以便在 HTML 中显示
4. Flask 路由 index:通过 GET 请求获取当前帧索引,动态生成点云和图像并嵌入 HTML
5. HTML 界面:提供输入框让用户选择帧索引,显示对应帧的点云和图像
6. 注意事项:如果 5000 端口被占用,可以更改为其他端口,如 5001
3.3 动态可视化
既然可以一帧一帧显示了,那能不能生成视频流看一下呢,执行以下代码:
import zarr
import cv2
import os# 加载 Zarr 数据集
data_path = "/home/yejiangchen/Desktop/Codes/3D-Diffusion-Policy-master/3D-Diffusion-Policy/data/adroit_hammer_expert.zarr"
dataset = zarr.open(data_path, mode='r')
images = dataset['data/img']# 视频保存路径
output_path = "output_video.avi"# 获取图像帧的高度和宽度
height, width, _ = images[0].shape# 定义视频编写器 (MJPG编码格式,帧率为30)
fourcc = cv2.VideoWriter_fourcc(*'MJPG')
out = cv2.VideoWriter(output_path, fourcc, 30.0, (width, height))# 将每一帧写入视频
for frame in images:frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)out.write(frame_bgr)# 释放视频写入对象
out.release()print(f"视频已保存到 {os.path.abspath(output_path)}")
运行结果:
小锤40......
此处展开说明一下:
此代码使用 cv2.VideoWriter 将 Zarr 数据集的图像帧保存为视频
视频编码器:使用 MJPG 编码保存为 .avi 格式,可以改用 XVID 或 MP4V 以保存为 .mp4
帧率:在 cv2.VideoWriter 中可以更改设置帧率
输出路径:生成的视频将保存在 output_video.avi