From 63c28ea39548b9f13d502cd357cae7193d3b82fb Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 9 Jan 2026 13:38:33 +0100 Subject: [PATCH] add cmd arg --- .../rac/rac_data_collection_openarms_rtc.py | 145 +++++++++--------- 1 file changed, 70 insertions(+), 75 deletions(-) diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py index 9568d2e92..bf3b017f7 100644 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -20,6 +20,13 @@ Controls: Usage: python examples/rac/rac_data_collection_openarms_rtc.py \ + --robot.type=openarms_follower \ + --robot.port_right=can0 \ + --robot.port_left=can1 \ + --robot.cameras="{ left_wrist: {type: opencv, index_or_path: 0}, right_wrist: {type: opencv, index_or_path: 2}}" \ + --teleop.type=openarms_mini \ + --teleop.port_right=/dev/ttyUSB0 \ + --teleop.port_left=/dev/ttyUSB1 \ --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ --dataset.repo_id=my_user/rac_rtc_dataset \ --dataset.single_task="Pick up the cube" @@ -37,7 +44,8 @@ from threading import Event, Lock, Thread import torch from torch import Tensor -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule @@ -49,10 +57,10 @@ from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.latency_tracker import LatencyTracker from lerobot.processor import make_default_processors -from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig -from lerobot.robots.openarms.openarms_follower import OpenArmsFollower -from lerobot.teleoperators import make_teleoperator_from_config -from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig +from lerobot.robots import Robot, RobotConfig, make_robot_from_config +from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401 +from lerobot.teleoperators import TeleoperatorConfig, make_teleoperator_from_config +from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401 from lerobot.utils.control_utils import is_headless from lerobot.utils.hub import HubMixin from lerobot.utils.robot_utils import precise_sleep @@ -64,7 +72,7 @@ logger = logging.getLogger(__name__) class RobotWrapper: """Thread-safe wrapper for robot operations.""" - def __init__(self, robot: OpenArmsFollower): + def __init__(self, robot: Robot): self.robot = robot self.lock = Lock() @@ -88,48 +96,44 @@ class RobotWrapper: def name(self) -> str: return self.robot.name -DEFAULT_FPS = 30 -DEFAULT_EPISODE_TIME_SEC = 120 -DEFAULT_NUM_EPISODES = 50 +@dataclass +class RaCRTCDatasetConfig: + """Dataset configuration for RaC + RTC.""" + repo_id: str = "lerobot/rac_rtc_openarms" + single_task: str = "task" + root: str | Path | None = None + fps: int = 30 + episode_time_s: float = 120 + num_episodes: int = 50 + video: bool = True + push_to_hub: bool = True + private: bool = True + num_image_writer_processes: int = 0 + num_image_writer_threads_per_camera: int = 4 -DEFAULT_CAMERA_CONFIG = { - "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_FPS), - "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_FPS), - "base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_FPS), -} @dataclass class RaCRTCConfig(HubMixin): """Configuration for RaC data collection with RTC.""" + robot: RobotConfig + teleop: TeleoperatorConfig + dataset: RaCRTCDatasetConfig policy: PreTrainedConfig | None = None rtc: RTCConfig = field( default_factory=lambda: RTCConfig( enabled=True, - execution_horizon=10, - max_guidance_weight=10.0, - prefix_attention_schedule=RTCAttentionSchedule.EXP, + execution_horizon=20, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.LINEAR, ) ) - dataset_repo_id: str = "lerobot/rac_rtc_openarms" - task: str = "task" - - fps: float = DEFAULT_FPS - episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC - num_episodes: int = DEFAULT_NUM_EPISODES - - follower_left_port: str = "can0" - follower_right_port: str = "can1" - teleop_left_port: str = "/dev/ttyUSB1" - teleop_right_port: str = "/dev/ttyUSB0" - device: str = "cuda" action_queue_size_to_get_new_actions: int = 30 - - interpolation: bool = False - push_to_hub: bool = True + interpolation: bool = True + play_sounds: bool = True def __post_init__(self): policy_path = parser.get_path_arg("policy") @@ -257,13 +261,14 @@ def get_actions_thread( shutdown_event: Event, cfg: RaCRTCConfig, policy_active: Event, + fps: int, ): """Thread for async action generation via RTC.""" try: logger.info("[GET_ACTIONS] Starting RTC action generation thread") latency_tracker = LatencyTracker() - time_per_chunk = 1.0 / cfg.fps + time_per_chunk = 1.0 / fps hw_features = hw_to_dataset_features(robot.observation_features, "observation") policy_device = policy.config.device @@ -311,7 +316,7 @@ def get_actions_thread( obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) - obs_with_policy_features["task"] = [cfg.task] + obs_with_policy_features["task"] = [cfg.dataset.single_task] obs_with_policy_features["robot_type"] = robot.name preprocessed_obs = preprocessor(obs_with_policy_features) @@ -363,14 +368,16 @@ def main(cfg: RaCRTCConfig): """Main RaC + RTC data collection.""" init_logging() + fps = cfg.dataset.fps + print("=" * 65) print(" RaC Data Collection with RTC - OpenArms") print("=" * 65) print(f" Policy: {cfg.policy.pretrained_path}") - print(f" Dataset: {cfg.dataset_repo_id}") - print(f" Task: {cfg.task}") - print(f" Policy Hz: {cfg.fps}") - print(f" Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}") + print(f" Dataset: {cfg.dataset.repo_id}") + print(f" Task: {cfg.dataset.single_task}") + print(f" Policy Hz: {fps}") + print(f" Robot Hz: {fps * 2 if cfg.interpolation else fps}") print(f" Interpolation: {cfg.interpolation}") print(f" RTC Enabled: {cfg.rtc.enabled}") print("=" * 65) @@ -389,26 +396,12 @@ def main(cfg: RaCRTCConfig): shutdown_event = Event() policy_active = Event() - follower_config = OpenArmsFollowerConfig( - port_left=cfg.follower_left_port, - port_right=cfg.follower_right_port, - can_interface="socketcan", - id="openarms_follower", - disable_torque_on_disconnect=True, - max_relative_target=10.0, - cameras=DEFAULT_CAMERA_CONFIG, - ) - - follower = OpenArmsFollower(follower_config) - follower.connect(calibrate=False) + follower = make_robot_from_config(cfg.robot) + follower.connect() robot = RobotWrapper(follower) logger.info("Robot connected") - teleop_config = OpenArmsMiniConfig( - port_right=cfg.teleop_right_port, - port_left=cfg.teleop_left_port, - ) - teleop = make_teleoperator_from_config(teleop_config) + teleop = make_teleoperator_from_config(cfg.teleop) teleop.connect() teleop.disable_torque() logger.info("Teleop connected") @@ -420,23 +413,25 @@ def main(cfg: RaCRTCConfig): aggregate_pipeline_dataset_features( pipeline=teleop_proc, initial_features=create_initial_features(action=action_features_hw), - use_videos=True, + use_videos=cfg.dataset.video, ), aggregate_pipeline_dataset_features( pipeline=obs_proc, initial_features=create_initial_features(observation=follower.observation_features), - use_videos=True, + use_videos=cfg.dataset.video, ), ) dataset = LeRobotDataset.create( - repo_id=cfg.dataset_repo_id, - fps=int(cfg.fps), + repo_id=cfg.dataset.repo_id, + fps=fps, + root=cfg.dataset.root, features=dataset_features, robot_type=follower.name, - use_videos=True, - image_writer_processes=0, - image_writer_threads=12, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera + * len(follower.cameras if hasattr(follower, "cameras") else []), ) dataset_lock = Lock() @@ -453,7 +448,7 @@ def main(cfg: RaCRTCConfig): get_actions_t = Thread( target=get_actions_thread, - args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active), + args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active, fps), daemon=True, name="GetActions", ) @@ -472,15 +467,15 @@ def main(cfg: RaCRTCConfig): if cfg.interpolation: interp_factor = 2 - robot_interval = 1.0 / (cfg.fps * interp_factor) + robot_interval = 1.0 / (fps * interp_factor) else: interp_factor = 1 - robot_interval = 1.0 / cfg.fps + robot_interval = 1.0 / fps try: recorded = 0 - while recorded < cfg.num_episodes and not events["stop_recording"]: - log_say(f"RaC episode {recorded + 1}", play_sounds=True) + while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"RaC episode {recorded + 1}", play_sounds=cfg.play_sounds) move_robot_to_zero(follower, duration_s=2.0) @@ -502,7 +497,7 @@ def main(cfg: RaCRTCConfig): policy_active.set() episode_start = time.perf_counter() - while (time.perf_counter() - episode_start) < cfg.episode_time_sec: + while (time.perf_counter() - episode_start) < cfg.dataset.episode_time_s: loop_start = time.perf_counter() if events["exit_early"]: @@ -531,7 +526,7 @@ def main(cfg: RaCRTCConfig): robot.send_action(robot_action) action_frame = build_dataset_frame(dataset_features, robot_action, prefix="action") - frame = {**obs_frame, **action_frame, "task": cfg.task} + frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} frame_buffer.append(frame) elif events["policy_paused"]: @@ -567,7 +562,7 @@ def main(cfg: RaCRTCConfig): robot_send_count += 1 action_frame = build_dataset_frame(dataset_features, action_dict, prefix="action") - frame = {**obs_frame, **action_frame, "task": cfg.task} + frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} frame_buffer.append(frame) now = time.perf_counter() @@ -588,7 +583,7 @@ def main(cfg: RaCRTCConfig): policy_active.clear() if events["rerecord_episode"]: - log_say("Re-recording", play_sounds=True) + log_say("Re-recording", play_sounds=cfg.play_sounds) events["rerecord_episode"] = False continue @@ -627,12 +622,12 @@ def main(cfg: RaCRTCConfig): if "gripper" in key: action[key] = -0.65 * action[key] robot.send_action(action) - precise_sleep(1 / cfg.fps) + precise_sleep(1 / fps) events["in_reset"] = False events["start_next_episode"] = False - log_say("Recording complete", play_sounds=True, blocking=True) + log_say("Recording complete", play_sounds=cfg.play_sounds, blocking=True) except KeyboardInterrupt: logger.info("Interrupted by user") @@ -651,9 +646,9 @@ def main(cfg: RaCRTCConfig): listener.stop() dataset.finalize() - if cfg.push_to_hub: + if cfg.dataset.push_to_hub: logger.info("Pushing to hub...") - dataset.push_to_hub(private=True) + dataset.push_to_hub(private=cfg.dataset.private) logger.info("Done")