From 98c33a4748027422ccdb063919ff17fd7b2a62a7 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 9 Jan 2026 13:26:25 +0100 Subject: [PATCH] Add RaC with RTC --- .../rac/rac_data_collection_openarms_rtc.py | 663 ++++++++++++++++++ 1 file changed, 663 insertions(+) create mode 100644 examples/rac/rac_data_collection_openarms_rtc.py diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py new file mode 100644 index 000000000..9568d2e92 --- /dev/null +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -0,0 +1,663 @@ +#!/usr/bin/env python +""" +RaC (Recovery and Correction) Data Collection with RTC for OpenArms Robot. + +Combines RaC paradigm with Real-Time Chunking (RTC) for smooth policy execution. +RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive +motion despite high inference latency by asynchronously generating action chunks. + +The workflow: +1. Policy runs via RTC (async action generation) - teleop is idle +2. Press SPACE/right pedal to pause - teleop moves to match robot position +3. Press 'c'/left pedal to take control - human provides RECOVERY + CORRECTION +4. Press →/right pedal to end episode (save and continue to next) + +Controls: + SPACE/Right pedal - Pause policy + c/Left pedal - Take control (start correction) + →/Right pedal - End episode (when in correction mode) + ESC - Stop recording and push to hub + +Usage: + python examples/rac/rac_data_collection_openarms_rtc.py \ + --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ + --dataset.repo_id=my_user/rac_rtc_dataset \ + --dataset.single_task="Pick up the cube" +""" + +import logging +import math +import sys +import time +import traceback +from dataclasses import dataclass, field +from pathlib import Path +from threading import Event, Lock, Thread + +import torch +from torch import Tensor + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies.rtc.action_queue import ActionQueue +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.latency_tracker import LatencyTracker +from lerobot.processor import make_default_processors +from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig +from lerobot.robots.openarms.openarms_follower import OpenArmsFollower +from lerobot.teleoperators import make_teleoperator_from_config +from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig +from lerobot.utils.control_utils import is_headless +from lerobot.utils.hub import HubMixin +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import init_logging, log_say + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class RobotWrapper: + """Thread-safe wrapper for robot operations.""" + + def __init__(self, robot: OpenArmsFollower): + self.robot = robot + self.lock = Lock() + + def get_observation(self) -> dict[str, Tensor]: + with self.lock: + return self.robot.get_observation() + + def send_action(self, action: dict) -> None: + with self.lock: + self.robot.send_action(action) + + @property + def observation_features(self) -> dict: + return self.robot.observation_features + + @property + def action_features(self) -> dict: + return self.robot.action_features + + @property + def name(self) -> str: + return self.robot.name + +DEFAULT_FPS = 30 +DEFAULT_EPISODE_TIME_SEC = 120 +DEFAULT_NUM_EPISODES = 50 + +DEFAULT_CAMERA_CONFIG = { + "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_FPS), + "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_FPS), + "base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_FPS), +} + +@dataclass +class RaCRTCConfig(HubMixin): + """Configuration for RaC data collection with RTC.""" + + policy: PreTrainedConfig | None = None + + rtc: RTCConfig = field( + default_factory=lambda: RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=10.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + ) + ) + + dataset_repo_id: str = "lerobot/rac_rtc_openarms" + task: str = "task" + + fps: float = DEFAULT_FPS + episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC + num_episodes: int = DEFAULT_NUM_EPISODES + + follower_left_port: str = "can0" + follower_right_port: str = "can1" + teleop_left_port: str = "/dev/ttyUSB1" + teleop_right_port: str = "/dev/ttyUSB0" + + device: str = "cuda" + action_queue_size_to_get_new_actions: int = 30 + + interpolation: bool = False + push_to_hub: bool = True + + def __post_init__(self): + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if self.policy is None: + raise ValueError("policy.path is required") + + @classmethod + def __get_path_fields__(cls) -> list[str]: + return ["policy"] + + +def init_keyboard_listener(events: dict): + """Initialize keyboard listener with RaC controls.""" + if is_headless(): + logger.warning("Headless environment - keyboard controls unavailable") + return None + + from pynput import keyboard + + def on_press(key): + try: + if events["in_reset"]: + if key == keyboard.Key.space or key == keyboard.Key.right: + events["start_next_episode"] = True + elif hasattr(key, "char") and key.char == "c": + events["start_next_episode"] = True + elif key == keyboard.Key.esc: + events["stop_recording"] = True + events["start_next_episode"] = True + else: + if key == keyboard.Key.space: + if not events["policy_paused"] and not events["correction_active"]: + print("\n[RaC] ⏸ PAUSED - Press 'c' to take control") + events["policy_paused"] = True + elif hasattr(key, "char") and key.char == "c": + if events["policy_paused"] and not events["correction_active"]: + print("\n[RaC] ▶ Taking control...") + events["start_correction"] = True + elif key == keyboard.Key.right: + print("\n[RaC] → End episode") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("\n[RaC] ← Re-record episode") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Key error: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + start_pedal_listener(events) + return listener + + +def start_pedal_listener(events: dict): + """Start foot pedal listener if available.""" + import threading + + try: + from evdev import InputDevice, ecodes + except ImportError: + return + + PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + KEY_LEFT = "KEY_A" + KEY_RIGHT = "KEY_C" + + def pedal_reader(): + try: + dev = InputDevice(PEDAL_DEVICE) + print(f"[Pedal] Connected: {dev.name}") + + for ev in dev.read_loop(): + if ev.type != ecodes.EV_KEY: + continue + + from evdev import categorize + + key = categorize(ev) + code = key.keycode + if isinstance(code, (list, tuple)): + code = code[0] + + if key.keystate != 1: + continue + + if events["in_reset"]: + if code in [KEY_LEFT, KEY_RIGHT]: + events["start_next_episode"] = True + else: + if code == KEY_RIGHT: + if events["correction_active"]: + events["exit_early"] = True + elif not events["policy_paused"]: + print("\n[Pedal] ⏸ PAUSED") + events["policy_paused"] = True + elif code == KEY_LEFT: + if events["policy_paused"] and not events["correction_active"]: + print("\n[Pedal] ▶ Taking control...") + events["start_correction"] = True + + except FileNotFoundError: + pass + except PermissionError: + logger.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}") + except Exception: + pass + + thread = threading.Thread(target=pedal_reader, daemon=True) + thread.start() + + +def get_actions_thread( + policy, + robot: RobotWrapper, + robot_observation_processor, + action_queue: ActionQueue, + shutdown_event: Event, + cfg: RaCRTCConfig, + policy_active: Event, +): + """Thread for async action generation via RTC.""" + try: + logger.info("[GET_ACTIONS] Starting RTC action generation thread") + + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / cfg.fps + + hw_features = hw_to_dataset_features(robot.observation_features, "observation") + policy_device = policy.config.device + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=None, + preprocessor_overrides={"device_processor": {"device": cfg.device}}, + ) + + get_actions_threshold = cfg.action_queue_size_to_get_new_actions + if not cfg.rtc.enabled: + get_actions_threshold = 0 + + while not shutdown_event.is_set(): + if not policy_active.is_set(): + time.sleep(0.01) + continue + + if action_queue.qsize() <= get_actions_threshold: + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() + + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + obs = robot.get_observation() + obs_processed = robot_observation_processor(obs) + + obs_with_policy_features = build_dataset_frame( + hw_features, obs_processed, prefix="observation" + ) + + for name in obs_with_policy_features: + obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = ( + obs_with_policy_features[name].type(torch.float32) / 255 + ) + obs_with_policy_features[name] = ( + obs_with_policy_features[name].permute(2, 0, 1).contiguous() + ) + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) + obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) + + obs_with_policy_features["task"] = [cfg.task] + obs_with_policy_features["robot_type"] = robot.name + + preprocessed_obs = preprocessor(obs_with_policy_features) + + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + ) + + original_actions = actions.squeeze(0).clone() + postprocessed_actions = postprocessor(actions).squeeze(0) + + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + + action_queue.merge( + original_actions, postprocessed_actions, new_delay, action_index_before_inference + ) + else: + time.sleep(0.01) + + logger.info("[GET_ACTIONS] Shutting down") + except Exception as e: + logger.error(f"[GET_ACTIONS] Fatal exception: {e}") + logger.error(traceback.format_exc()) + shutdown_event.set() + sys.exit(1) + + +def move_robot_to_zero(robot, duration_s: float = 2.0, fps: int = 50): + """Smoothly move robot to zero position.""" + obs = robot.get_observation() + current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} + target_pos = {k: 0.0 for k in current_pos} + + print(f"[RaC] Moving robot to zero ({duration_s}s)...") + steps = int(duration_s * fps) + for step in range(steps + 1): + t = step / steps + interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos} + robot.send_action(interp_pos) + time.sleep(1 / fps) + + +@parser.wrap() +def main(cfg: RaCRTCConfig): + """Main RaC + RTC data collection.""" + init_logging() + + print("=" * 65) + print(" RaC Data Collection with RTC - OpenArms") + print("=" * 65) + print(f" Policy: {cfg.policy.pretrained_path}") + print(f" Dataset: {cfg.dataset_repo_id}") + print(f" Task: {cfg.task}") + print(f" Policy Hz: {cfg.fps}") + print(f" Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}") + print(f" Interpolation: {cfg.interpolation}") + print(f" RTC Enabled: {cfg.rtc.enabled}") + print("=" * 65) + + events = { + "exit_early": False, + "rerecord_episode": False, + "stop_recording": False, + "policy_paused": False, + "correction_active": False, + "start_correction": False, + "in_reset": False, + "start_next_episode": False, + } + + shutdown_event = Event() + policy_active = Event() + + follower_config = OpenArmsFollowerConfig( + port_left=cfg.follower_left_port, + port_right=cfg.follower_right_port, + can_interface="socketcan", + id="openarms_follower", + disable_torque_on_disconnect=True, + max_relative_target=10.0, + cameras=DEFAULT_CAMERA_CONFIG, + ) + + follower = OpenArmsFollower(follower_config) + follower.connect(calibrate=False) + robot = RobotWrapper(follower) + logger.info("Robot connected") + + teleop_config = OpenArmsMiniConfig( + port_right=cfg.teleop_right_port, + port_left=cfg.teleop_left_port, + ) + teleop = make_teleoperator_from_config(teleop_config) + teleop.connect() + teleop.disable_torque() + logger.info("Teleop connected") + + teleop_proc, robot_proc, obs_proc = make_default_processors() + action_features_hw = {k: v for k, v in follower.action_features.items() if k.endswith(".pos")} + + dataset_features = combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=teleop_proc, + initial_features=create_initial_features(action=action_features_hw), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=obs_proc, + initial_features=create_initial_features(observation=follower.observation_features), + use_videos=True, + ), + ) + + dataset = LeRobotDataset.create( + repo_id=cfg.dataset_repo_id, + fps=int(cfg.fps), + features=dataset_features, + robot_type=follower.name, + use_videos=True, + image_writer_processes=0, + image_writer_threads=12, + ) + dataset_lock = Lock() + + policy_class = get_policy_class(cfg.policy.type) + config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) + policy.config.rtc_config = cfg.rtc + policy.init_rtc_processor() + policy = policy.to(cfg.device) + policy.eval() + logger.info(f"Policy loaded: {policy.name}") + + action_queue = ActionQueue(cfg.rtc) + + get_actions_t = Thread( + target=get_actions_thread, + args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active), + daemon=True, + name="GetActions", + ) + get_actions_t.start() + + listener = init_keyboard_listener(events) + + print("\n Controls:") + print(" SPACE/Right pedal - Pause policy") + print(" c/Left pedal - Take control") + print(" →/Right pedal - End episode (in correction mode)") + print(" ESC - Stop and push to hub") + print("=" * 65 + "\n") + + action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] + + if cfg.interpolation: + interp_factor = 2 + robot_interval = 1.0 / (cfg.fps * interp_factor) + else: + interp_factor = 1 + robot_interval = 1.0 / cfg.fps + + try: + recorded = 0 + while recorded < cfg.num_episodes and not events["stop_recording"]: + log_say(f"RaC episode {recorded + 1}", play_sounds=True) + + move_robot_to_zero(follower, duration_s=2.0) + + action_queue = ActionQueue(cfg.rtc) + events["policy_paused"] = False + events["correction_active"] = False + events["start_correction"] = False + events["exit_early"] = False + + frame_buffer = [] + prev_action: Tensor | None = None + interpolated_actions: list[Tensor] = [] + interp_idx = 0 + + robot_send_count = 0 + policy_consume_count = 0 + last_hz_print = time.perf_counter() + + policy_active.set() + episode_start = time.perf_counter() + + while (time.perf_counter() - episode_start) < cfg.episode_time_sec: + loop_start = time.perf_counter() + + if events["exit_early"]: + break + + if events["start_correction"] and not events["correction_active"]: + policy_active.clear() + print("[RaC] Moving teleop to robot position...") + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + teleop.disable_torque() + events["correction_active"] = True + events["start_correction"] = False + print("[RaC] Correction mode - you have control") + + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} + obs_frame = build_dataset_frame(dataset_features, obs_filtered, prefix="observation") + + if events["correction_active"]: + robot_action = teleop.get_action() + for key in robot_action: + if "gripper" in key: + robot_action[key] = -0.65 * robot_action[key] + robot.send_action(robot_action) + + action_frame = build_dataset_frame(dataset_features, robot_action, prefix="action") + frame = {**obs_frame, **action_frame, "task": cfg.task} + frame_buffer.append(frame) + + elif events["policy_paused"]: + pass + + else: + if interp_idx >= len(interpolated_actions): + new_action = action_queue.get() + if new_action is not None: + current_action = new_action.cpu() + policy_consume_count += 1 + + if cfg.interpolation and prev_action is not None: + mid = prev_action + 0.5 * (current_action - prev_action) + interpolated_actions = [mid, current_action] + else: + interpolated_actions = [current_action] + + prev_action = current_action + interp_idx = 0 + + if interp_idx < len(interpolated_actions): + action_to_send = interpolated_actions[interp_idx] + interp_idx += 1 + + action_dict = {} + for i, key in enumerate(action_keys): + if i < len(action_to_send): + action_dict[key] = action_to_send[i].item() + + action_processed = robot_proc((action_dict, None)) + robot.send_action(action_processed) + robot_send_count += 1 + + action_frame = build_dataset_frame(dataset_features, action_dict, prefix="action") + frame = {**obs_frame, **action_frame, "task": cfg.task} + frame_buffer.append(frame) + + now = time.perf_counter() + if now - last_hz_print >= 5.0: + elapsed = now - last_hz_print + actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0 + actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0 + logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}") + robot_send_count = 0 + policy_consume_count = 0 + last_hz_print = now + + dt = time.perf_counter() - loop_start + sleep_time = max(0, robot_interval - dt - 0.001) + if sleep_time > 0: + precise_sleep(sleep_time) + + policy_active.clear() + + if events["rerecord_episode"]: + log_say("Re-recording", play_sounds=True) + events["rerecord_episode"] = False + continue + + for frame in frame_buffer: + dataset.add_frame(frame) + + with dataset_lock: + if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0: + logger.info(f"Saving episode ({dataset.episode_buffer['size']} frames)") + dataset.save_episode() + + recorded += 1 + + if recorded < cfg.num_episodes and not events["stop_recording"]: + events["in_reset"] = True + events["start_next_episode"] = False + print("\n[RaC] RESET - Press any key/pedal to enable teleop") + + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + + while not events["start_next_episode"] and not events["stop_recording"]: + precise_sleep(0.05) + + if events["stop_recording"]: + break + + events["start_next_episode"] = False + teleop.disable_torque() + print("[RaC] Teleop enabled - move to start, then press key/pedal") + + while not events["start_next_episode"] and not events["stop_recording"]: + action = teleop.get_action() + for key in action: + if "gripper" in key: + action[key] = -0.65 * action[key] + robot.send_action(action) + precise_sleep(1 / cfg.fps) + + events["in_reset"] = False + events["start_next_episode"] = False + + log_say("Recording complete", play_sounds=True, blocking=True) + + except KeyboardInterrupt: + logger.info("Interrupted by user") + + finally: + shutdown_event.set() + policy_active.clear() + + if get_actions_t.is_alive(): + get_actions_t.join(timeout=5.0) + + follower.disconnect() + teleop.disconnect() + + if listener: + listener.stop() + + dataset.finalize() + if cfg.push_to_hub: + logger.info("Pushing to hub...") + dataset.push_to_hub(private=True) + + logger.info("Done") + + +if __name__ == "__main__": + main() +