mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
feat(umi): add EE replay viewer, URDF meshes, and evaluate script updates
- Add replay.py script and replay_viewer.html for browser-based EE trajectory visualization from glannuzel/grabette-dataset - Add viewer.html for interactive URDF inspection - Move OpenArm URDF and meshes into openarm_follower/urdf/ - Add virtual EE target frame (openarm_right_ee_target) at 7cm from link7 - Adapt evaluate.py for single right-arm OpenArm with wrist camera - Update docs with replay viewer usage - Update openarm_follower config, driver, and kinematic processor Made-with: Cursor
This commit is contained in:
@@ -86,11 +86,12 @@ HF_MODEL_ID = "pepijn223/grabette-umi-pi0"
|
||||
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
|
||||
LATENCY_SKIP_STEPS = 0
|
||||
|
||||
# EE feature keys produced by ForwardKinematicsJointsToEE
|
||||
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
# EE feature keys produced by ForwardKinematicsJointsToEE (arm pose only).
|
||||
# Gripper joints use absolute position control, not EE-relative.
|
||||
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz"]
|
||||
|
||||
URDF_PATH = "./SO101/so101_new_calib.urdf"
|
||||
URDF_EE_FRAME = "gripper_frame_link"
|
||||
URDF_PATH = "src/lerobot/robots/openarm_follower/urdf/openarm_bimanual_pybullet.urdf"
|
||||
URDF_EE_FRAME = "openarm_right_link7"
|
||||
|
||||
|
||||
def main():
|
||||
@@ -101,38 +102,47 @@ def main():
|
||||
side="right",
|
||||
cameras=camera_config,
|
||||
max_relative_target=8.0,
|
||||
gripper_port="/dev/ttyUSB0",
|
||||
)
|
||||
robot = OpenArmFollower(robot_config)
|
||||
|
||||
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
|
||||
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
|
||||
|
||||
motor_names = list(robot.bus.motors.keys())
|
||||
arm_motor_names = list(robot.bus.motors.keys())
|
||||
gripper_motor_names = list(robot.gripper_bus.motors.keys())
|
||||
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path=URDF_PATH,
|
||||
target_frame_name=URDF_EE_FRAME,
|
||||
joint_names=motor_names,
|
||||
joint_names=arm_motor_names,
|
||||
)
|
||||
|
||||
# FK: joint observation → EE observation (produces observation.state)
|
||||
# The policy starts from the robot's current EE pose (via FK below).
|
||||
# Relative actions are predicted as deltas from that pose, so no manual
|
||||
# re-centering is needed — the starting point is always the live EE tip.
|
||||
|
||||
# FK: joint observation → EE observation (produces observation.state).
|
||||
# gripper_names=[] means proximal/distal pass through as absolute positions.
|
||||
robot_joints_to_ee_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
motor_names=arm_motor_names,
|
||||
gripper_names=[],
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# IK: EE action → joint targets
|
||||
# IK: EE action → joint targets. Gripper actions are absolute and pass through.
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
motor_names=arm_motor_names,
|
||||
gripper_names=[],
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
@@ -162,7 +172,13 @@ def main():
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) for k in EE_KEYS}
|
||||
action={
|
||||
**{f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) for k in EE_KEYS},
|
||||
**{
|
||||
f"{g}.pos": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for g in gripper_motor_names
|
||||
},
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Replay a dataset episode in EE frame using a browser-based URDF viewer.
|
||||
|
||||
Extracts ``observation.pose`` from the dataset, saves a trajectory JSON file,
|
||||
then launches a local HTTP server and opens the replay viewer. The trajectory
|
||||
is re-centered so frame 0 starts at the OpenArm ``openarm_right_ee_target``
|
||||
EE tip (zero-joint pose).
|
||||
|
||||
Usage:
|
||||
python replay.py
|
||||
python replay.py --episode 3 --repo-id myuser/mydata
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.server
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
VIEWER_DIR = Path(__file__).resolve().parents[2] / "src/lerobot/robots/openarm_follower/urdf"
|
||||
TRAJECTORY_FILENAME = "trajectory_ep0.json"
|
||||
|
||||
|
||||
def extract_trajectory(repo_id: str, episode: int, output_path: Path) -> dict:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset(repo_id, episodes=[episode])
|
||||
poses = dataset.select_columns("observation.pose")
|
||||
actions = dataset.select_columns("action")
|
||||
|
||||
frames = []
|
||||
for i in range(dataset.num_frames):
|
||||
p = poses[i]["observation.pose"]
|
||||
a = actions[i]["action"]
|
||||
frames.append(
|
||||
{
|
||||
"x": float(p[0]),
|
||||
"y": float(p[1]),
|
||||
"z": float(p[2]),
|
||||
"ax": float(p[3]),
|
||||
"ay": float(p[4]),
|
||||
"az": float(p[5]),
|
||||
"proximal": float(a[0]),
|
||||
"distal": float(a[1]),
|
||||
}
|
||||
)
|
||||
payload = {"fps": dataset.fps, "num_frames": dataset.num_frames, "frames": frames}
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(payload, f)
|
||||
print(f"Extracted {dataset.num_frames} frames at {dataset.fps} FPS → {output_path}")
|
||||
return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Viewer mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def serve_and_open(directory: Path, port: int = 8765):
|
||||
os.chdir(directory)
|
||||
handler = http.server.SimpleHTTPRequestHandler
|
||||
httpd = http.server.HTTPServer(("", port), handler)
|
||||
url = f"http://localhost:{port}/replay_viewer.html"
|
||||
print(f"Serving at {url}")
|
||||
threading.Thread(target=lambda: webbrowser.open(url), daemon=True).start()
|
||||
try:
|
||||
httpd.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("\nServer stopped.")
|
||||
httpd.server_close()
|
||||
|
||||
|
||||
def run_viewer(args):
|
||||
trajectory_path = VIEWER_DIR / TRAJECTORY_FILENAME
|
||||
if not trajectory_path.exists() or args.force:
|
||||
extract_trajectory(args.repo_id, args.episode, trajectory_path)
|
||||
else:
|
||||
print(f"Using cached trajectory at {trajectory_path} (pass --force to re-extract)")
|
||||
serve_and_open(VIEWER_DIR, args.port)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Replay a dataset episode in EE frame (URDF viewer)")
|
||||
parser.add_argument("--repo-id", default="glannuzel/grabette-dataset")
|
||||
parser.add_argument("--episode", type=int, default=0)
|
||||
parser.add_argument("--port", type=int, default=8765)
|
||||
parser.add_argument("--force", action="store_true", help="Re-extract trajectory even if cached")
|
||||
args = parser.parse_args()
|
||||
run_viewer(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user