From 467981eaef6e91213de12ed557d1848e16259811 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 21 Jan 2026 16:39:53 +0100 Subject: [PATCH] Add changes from openarms experiments --- docs/source/_toctree.yml | 2 + docs/source/{rac.mdx => hil_collection.mdx} | 65 +- examples/rac/rac_data_collection.py | 303 ++++-- .../rac/rac_data_collection_openarms_rtc.py | 889 ++++++++++++++++++ 4 files changed, 1173 insertions(+), 86 deletions(-) rename docs/source/{rac.mdx => hil_collection.mdx} (85%) create mode 100644 examples/rac/rac_data_collection_openarms_rtc.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7766b3472..71d44c1a2 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -19,6 +19,8 @@ title: Train RL in Simulation - local: multi_gpu_training title: Multi GPU training + - local: hil_collection + title: Human In the Loop: Recovery and Correction Data Collection title: "Tutorials" - sections: - local: lerobot-dataset-v3 diff --git a/docs/source/rac.mdx b/docs/source/hil_collection.mdx similarity index 85% rename from docs/source/rac.mdx rename to docs/source/hil_collection.mdx index 022637ebf..b77891548 100644 --- a/docs/source/rac.mdx +++ b/docs/source/hil_collection.mdx @@ -1,13 +1,7 @@ -# RaC: Recovery and Correction Training +# Human In the Loop: Recovery and Correction Data Collection RaC (Recovery and Correction) is a human-in-the-loop data collection and training paradigm that improves robot policy performance on long-horizon tasks by explicitly teaching recovery and correction behaviors. -**Key References:** -- [RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction](https://arxiv.org/abs/2509.07953) (Hu et al., 2025) -- [HG-DAgger: Interactive Imitation Learning with Human Experts](https://arxiv.org/abs/1810.02890) (Kelly et al., 2019) -- [π∗0.6: a VLA That Learns From Experience](https://pi.website/blog/pistar06) (Physical Intelligence, 2025) -- [SARM: Stage-Aware Reward Modeling](https://arxiv.org/abs/2509.25358) (Chen et al., 2025) - --- ## Why RaC? The Problem with Standard Data Collection @@ -15,7 +9,7 @@ RaC (Recovery and Correction) is a human-in-the-loop data collection and trainin ### Standard Behavioral Cloning Data Collection Limitations Standard behavior cloning trains policies on successful demonstrations. This approach can be sensitive to distribution shift and compounding errors. Because during deployment small errors can cascade and push the robot into states never seen during training. -This is where RaC and methods like Dagger and HG-DAgger come in. +This is where RaC whick builds on work like Dagger and HG-DAgger comes in. ### Prior Human-in-the-Loop Methods @@ -38,7 +32,9 @@ BC/DAgger: policy → mistake → human corrects → continue RaC: policy → mistake → human RECOVERS (teleop back) → CORRECTS → END ``` -The critical insight is **Rule 1 (Recover then Correct)**: +THis Human in the loop approach follows two rules + +*Rule 1 (Recover then Correct)**: - Every intervention starts with human teleoperating back to an in-distribution state - Then human provides correction to complete the current subtask - Both segments are recorded as training data @@ -47,7 +43,6 @@ The critical insight is **Rule 1 (Recover then Correct)**: **Rule 2 (Terminate after Intervention)**: - Episode ends after correction completes - Avoids mixed policy/human data on later subtasks -- Keeps data distribution clean --- @@ -62,7 +57,7 @@ The critical insight is **Rule 1 (Recover then Correct)**: --- -## The RaC Pipeline +## The Pipeline ``` ┌─────────────────────────────────────────────────────────────────────────┐ @@ -122,23 +117,41 @@ python examples/rac/rac_data_collection.py \ --dataset.num_episodes=50 ``` -**Keyboard Controls:** +**Controls (Keyboard + Foot Pedal):** -| Key | Action | -|-----|--------| -| **SPACE** | Start intervention (take control) | -| **→** | End episode (save) | -| **ESC** | Stop recording session | +| Key / Pedal | Action | +|-------------|--------| +| **SPACE** / Right pedal | Pause policy (teleop mirrors robot, no recording) | +| **c** / Left pedal | Take control (start correction, recording resumes) | +| **→** / Right pedal | End episode (save) - when in correction mode | +| **←** | Re-record episode | +| **ESC** | Stop session and push to hub | +| Any key/pedal during reset | Start next episode | **The RaC Protocol:** -1. Watch the policy run autonomously -2. When you see imminent failure, press **SPACE** to intervene -3. **RECOVERY**: Teleoperate the robot back to a good in-distribution state -4. **CORRECTION**: Use teleoperator to complete the subtask -5. Press **→** to save and end episode +1. Watch the policy run autonomously (teleop is idle/free) +2. When you see imminent failure, press **SPACE** or **right pedal** to pause + - Policy stops + - Teleoperator moves to match robot position (torque enabled) + - No frames recorded during pause +3. Press **c** or **left pedal** to take control + - Teleoperator torque disabled, free to move + - **RECOVERY**: Teleoperate back to a good state + - **CORRECTION**: Complete the subtask + - All movements are recorded +4. Press **→** or **right pedal** to save and end episode +5. **RESET**: Teleop moves to robot position, you can move robot to starting position +6. Press any key/pedal to start next episode -The recovery segment (teleoperating back to good state) is recorded as training data - this teaches the policy how to recover from errors. +The recovery and correction segments teach the policy how to recover from errors. + +**Foot Pedal Setup (Linux):** + +If using a USB foot pedal (PCsensor FootSwitch), ensure access: +```bash +sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd +``` ### Step 3: (Optional) Compute SARM Rewards @@ -233,11 +246,6 @@ RaC can be applied iteratively: └─────────────────────────────────────────────────────────────────────────┘ ``` -Each iteration: -1. Deploy current policy -2. Collect RaC interventions on failure cases -3. Fine-tune on accumulated data - --- ## References @@ -271,3 +279,4 @@ Each iteration: } ``` + diff --git a/examples/rac/rac_data_collection.py b/examples/rac/rac_data_collection.py index 2f4018481..62863b886 100644 --- a/examples/rac/rac_data_collection.py +++ b/examples/rac/rac_data_collection.py @@ -9,10 +9,10 @@ RaC improves upon standard data collection (BC) and prior human-in-the-loop meth (DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors: The workflow: -1. Policy runs autonomously until human presses SPACE to intervene -2. On intervention: human teleoperates the robot back to a good state (RECOVERY) -3. Human provides CORRECTION with teleoperator to complete the subtask -4. Press -> to end episode (save and continue to next) +1. Policy runs autonomously +2. Press SPACE to pause - robot holds position +3. Press 'c' to take control - human provides RECOVERY + CORRECTION +4. Press → to end episode (save and continue to next) 5. Reset, then do next rollout Key RaC Rules: @@ -23,9 +23,11 @@ The recovery segment (teleoperating back to good state) is recorded as training this teaches the policy how to recover from errors. Keyboard Controls: - SPACE - Start intervention (policy stops, human takes over) + SPACE - Pause policy (robot holds position, no recording) + c - Take control (start correction, recording resumes) → - End episode (save and continue to next) - ESC - Stop recording session + ← - Re-record episode + ESC - Stop recording and push dataset to hub Usage: python examples/rac/rac_data_collection.py \ @@ -129,7 +131,10 @@ def init_rac_keyboard_listener(): "exit_early": False, "rerecord_episode": False, "stop_recording": False, - "intervention_active": False, + "policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot + "correction_active": False, # 'c' pressed - human controlling, recording correction + "in_reset": False, # True during reset period + "start_next_episode": False, # Signal to start next episode } if is_headless(): @@ -140,32 +145,119 @@ def init_rac_keyboard_listener(): def on_press(key): try: - if key == keyboard.Key.space: - if not events["intervention_active"]: - print("\n[RaC] ▶ INTERVENTION - You have control") - print(" 1. Teleoperate robot back to good state (RECOVERY)") - print(" 2. Complete the subtask (CORRECTION)") - print(" 3. Press → when done") - events["intervention_active"] = True - elif key == keyboard.Key.right: - print("[RaC] → End episode") - events["exit_early"] = True - elif key == keyboard.Key.left: - print("[RaC] ← Re-record episode") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("[RaC] ESC - Stop recording session") - events["stop_recording"] = True - events["exit_early"] = True + if events["in_reset"]: + # During reset: any action key starts next episode + if key == keyboard.Key.space or key == keyboard.Key.right: + print("\n[RaC] Starting next episode...") + events["start_next_episode"] = True + elif hasattr(key, 'char') and key.char == 'c': + print("\n[RaC] Starting next episode...") + events["start_next_episode"] = True + elif key == keyboard.Key.esc: + print("[RaC] ESC - Stop recording, pushing to hub...") + events["stop_recording"] = True + events["start_next_episode"] = True + else: + # During episode + if key == keyboard.Key.space: + if not events["policy_paused"] and not events["correction_active"]: + print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position") + print(" Press 'c' or START to take control") + events["policy_paused"] = True + elif hasattr(key, 'char') and key.char == 'c': + if events["policy_paused"] and not events["correction_active"]: + print("\n[RaC] ▶ START pressed - taking control") + events["start_next_episode"] = True + elif key == keyboard.Key.right: + print("[RaC] → End episode") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("[RaC] ← Re-record episode") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("[RaC] ESC - Stop recording, pushing to hub...") + events["stop_recording"] = True + events["exit_early"] = True except Exception as e: print(f"Key error: {e}") listener = keyboard.Listener(on_press=on_press) listener.start() + + start_pedal_listener(events) + return listener, events +def start_pedal_listener(events: dict): + """Start foot pedal listener thread if evdev is available.""" + import threading + + try: + from evdev import InputDevice, ecodes + except ImportError: + logging.info("[Pedal] evdev not installed - pedal support disabled") + return + + PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + KEY_LEFT = "KEY_A" # Left pedal + KEY_RIGHT = "KEY_C" # Right pedal + + def pedal_reader(): + try: + dev = InputDevice(PEDAL_DEVICE) + print(f"[Pedal] Connected: {dev.name}") + print(f"[Pedal] Right=pause/next, Left=take control/start") + + for ev in dev.read_loop(): + if ev.type != ecodes.EV_KEY: + continue + + from evdev import categorize + key = categorize(ev) + code = key.keycode + if isinstance(code, (list, tuple)): + code = code[0] + + # Only trigger on key down + if key.keystate != 1: + continue + + if events["in_reset"]: + # During reset: either pedal starts next episode + if code in [KEY_LEFT, KEY_RIGHT]: + print("\n[Pedal] Starting next episode...") + events["start_next_episode"] = True + else: + # During episode + if code == KEY_RIGHT: + # Right pedal: SPACE (pause) when running, → (next) when in correction + if events["correction_active"]: + print("\n[Pedal] → End episode") + events["exit_early"] = True + elif not events["policy_paused"]: + print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot") + print(" Press left pedal to take control") + events["policy_paused"] = True + + elif code == KEY_LEFT: + # Left pedal: START (take control) when paused + if events["policy_paused"] and not events["correction_active"]: + print("\n[Pedal] ▶ START pressed - taking control") + events["start_next_episode"] = True + + except FileNotFoundError: + logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") + except PermissionError: + logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}") + except Exception as e: + logging.debug(f"[Pedal] Error: {e}") + + thread = threading.Thread(target=pedal_reader, daemon=True) + thread.start() + + def make_identity_processors(): """Create identity processors for RaC recording.""" teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( @@ -186,6 +278,21 @@ def make_identity_processors(): return teleop_proc, robot_proc, obs_proc +def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50): + """Smoothly move all robot joints to zero position.""" + obs = robot.get_observation() + current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} + target_pos = {k: 0.0 for k in current_pos} + + print(f"[RaC] Moving robot to zero position ({duration_s}s)...") + steps = int(duration_s * fps) + for step in range(steps + 1): + t = step / steps + interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos} + robot.send_action(interp_pos) + time.sleep(1 / fps) + print("[RaC] Robot at zero position.") + @safe_stop_image_writer def rac_rollout_loop( robot: Robot, @@ -201,10 +308,12 @@ def rac_rollout_loop( display_data: bool = True, ) -> dict: """ - RaC rollout loop: policy runs until intervention, then human does recovery+correction. - - The human intervention (recovery + correction) is recorded as training data. - This teaches the policy how to recover from errors. + RaC rollout loop with two-stage intervention: + + 1. Policy runs autonomously (recording) + 2. SPACE: Policy pauses (NOT recording) - robot holds position + 3. 'c': Human takes control (recording correction) + 4. →: End episode """ policy.reset() preprocessor.reset() @@ -216,10 +325,14 @@ def rac_rollout_loop( stats = { "total_frames": 0, "autonomous_frames": 0, - "human_frames": 0, - "intervention_occurred": False, + "paused_frames": 0, + "correction_frames": 0, } + last_robot_action = None + was_paused = False + was_correction_active = False + waiting_for_takeover = False timestamp = 0 start_t = time.perf_counter() @@ -228,13 +341,59 @@ def rac_rollout_loop( if events["exit_early"]: events["exit_early"] = False - events["intervention_active"] = False + events["policy_paused"] = False + events["correction_active"] = False break + # Detect transition to paused state + if events["policy_paused"] and not was_paused: + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} + print("[RaC] Moving teleop to robot position (2s smooth transition)...") + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + print("[RaC] Teleop aligned. Press START to take control.") + events["start_next_episode"] = False + waiting_for_takeover = True + was_paused = True + + # Wait for start button before enabling correction mode + if waiting_for_takeover and events["start_next_episode"]: + print("[RaC] Start pressed - enabling teleop control...") + events["start_next_episode"] = False + events["correction_active"] = True + waiting_for_takeover = False + was_correction_active = True + obs = robot.get_observation() obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR) - if not events["intervention_active"]: + if events["correction_active"]: + # Human controlling - record correction data + robot_action = teleop.get_action() + robot.send_action(robot_action) + stats["correction_frames"] += 1 + + # Record this frame + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 + + elif waiting_for_takeover: + # Waiting for START - policy stopped, no recording, robot holds position + if last_robot_action is not None: + robot.send_action(last_robot_action) + stats["paused_frames"] += 1 + + elif events["policy_paused"]: + # Paused and user acknowledged - hold last position, don't record + if last_robot_action is not None: + robot.send_action(last_robot_action) + stats["paused_frames"] += 1 + robot_action = last_robot_action + + else: + # Normal policy execution - record action_values = predict_action( observation=obs_frame, policy=policy, @@ -246,22 +405,18 @@ def rac_rollout_loop( robot_type=robot.robot_type, ) robot_action: RobotAction = make_robot_action(action_values, dataset.features) + robot.send_action(robot_action) + last_robot_action = robot_action stats["autonomous_frames"] += 1 - else: - stats["intervention_occurred"] = True - robot_action = teleop.get_action() - action_values = robot_action - stats["human_frames"] += 1 + + # Record this frame + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 - robot.send_action(robot_action) - - action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) - frame = {**obs_frame, **action_frame, "task": single_task} - frame_buffer.append(frame) - stats["total_frames"] += 1 - - if display_data: - log_rerun_data(observation=obs, action=action_values) + if display_data and robot_action is not None: + log_rerun_data(observation=obs, action=robot_action) dt = time.perf_counter() - loop_start precise_sleep(1 / fps - dt) @@ -278,15 +433,37 @@ def reset_loop( teleop: Teleoperator, events: dict, fps: int, - reset_time_s: float, ): - """Reset period where human repositions environment.""" - print(f"\n[RaC] Reset time: {reset_time_s}s - reposition environment") + """Reset period where human repositions environment. Two-stage: enable teleop, then start episode.""" + print("\n" + "=" * 65) + print(" [RaC] RESET - Moving teleop to robot position...") + print("=" * 65) + + # Enter reset mode + events["in_reset"] = True + events["start_next_episode"] = False + + # Move teleop to match robot position to avoid sudden jumps + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + + # Stage 1: Wait for user to press start to enable teleoperation + print(" Teleop aligned. Press any key/pedal to enable teleoperation") + while not events["start_next_episode"] and not events["stop_recording"]: + precise_sleep(0.05) + + if events["stop_recording"]: + return + + # Stage 2: Enable teleop and let user move robot to starting position + events["start_next_episode"] = False + teleop.disable_torque() + print(" Teleop enabled - move robot to starting position") + print(" Press any key/pedal to start next episode") - timestamp = 0 - start_t = time.perf_counter() - - while timestamp < reset_time_s and not events["exit_early"]: + # Wait for user to signal ready for next episode + while not events["start_next_episode"] and not events["stop_recording"]: loop_start = time.perf_counter() action = teleop.get_action() @@ -294,7 +471,13 @@ def reset_loop( dt = time.perf_counter() - loop_start precise_sleep(1 / fps - dt) - timestamp = time.perf_counter() - start_t + + # Exit reset mode and clear flags for next episode + events["in_reset"] = False + events["start_next_episode"] = False + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False @parser.wrap() @@ -374,15 +557,19 @@ def rac_collect(cfg: RaCConfig) -> LeRobotDataset: print(" Policy runs autonomously until you intervene.") print() print(" Controls:") - print(" SPACE - Intervene (take control)") + print(" SPACE - Pause policy (robot holds position, no recording)") + print(" c - Take control (start correction, recording)") print(" → - End episode (save)") - print(" ESC - Stop recording session") + print(" ← - Re-record episode") + print(" ESC - Stop session and push to hub") print("=" * 65 + "\n") with VideoEncodingManager(dataset): recorded = 0 while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) + + move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps) stats = rac_rollout_loop( robot=robot, @@ -417,7 +604,6 @@ def rac_collect(cfg: RaCConfig) -> LeRobotDataset: teleop=teleop, events=events, fps=cfg.dataset.fps, - reset_time_s=cfg.dataset.reset_time_s, ) finally: @@ -450,3 +636,4 @@ def main(): if __name__ == "__main__": main() + diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py new file mode 100644 index 000000000..f89248ff8 --- /dev/null +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -0,0 +1,889 @@ +#!/usr/bin/env python +""" +RaC (Recovery and Correction) Data Collection for OpenArms Robot with RTC. + +This combines RaC data collection with Real-Time Chunking (RTC) for smooth policy execution. +RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion +despite high inference latency by asynchronously generating action chunks. + +The workflow: +1. Policy runs autonomously with RTC (teleop is idle/free) +2. Press SPACE to pause - teleop moves to match robot position +3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION +4. Press → to end episode (save and continue to next) +5. Reset, then do next rollout + +Usage: + python examples/rac/rac_data_collection_openarms_rtc.py \ + --robot.port_right=can0 \ + --robot.port_left=can1 \ + --teleop.port_right=/dev/ttyUSB0 \ + --teleop.port_left=/dev/ttyUSB1 \ + --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ + --dataset.repo_id=my_user/rac_openarms_dataset \ + --dataset.single_task="Pick up the cube" +""" + +import logging +import math +import time +from dataclasses import dataclass, field +from pathlib import Path +from pprint import pformat +from threading import Event, Lock, Thread +from typing import Any + +import torch +from torch import Tensor + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.image_writer import safe_stop_image_writer +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features +from lerobot.datasets.video_utils import VideoEncodingManager +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc.action_queue import ActionQueue +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.latency_tracker import LatencyTracker +from lerobot.policies.utils import make_robot_action +from lerobot.processor import ( + IdentityProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RobotAction, + RobotObservation, + RobotProcessorPipeline, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.processor.rename_processor import rename_stats +from lerobot.robots import Robot, RobotConfig, make_robot_from_config +from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401 +from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config +from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401 +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.control_utils import is_headless, predict_action +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Configuration +# ============================================================================ + +@dataclass +class RaCRTCDatasetConfig: + repo_id: str = "lerobot/rac_openarms_rtc" + single_task: str = "default task" + root: str | Path | None = None + fps: int = 30 + episode_time_s: float = 500 + reset_time_s: float = 30 + num_episodes: int = 50 + video: bool = True + push_to_hub: bool = True + private: bool = False + tags: list[str] | None = None + num_image_writer_processes: int = 0 + num_image_writer_threads_per_camera: int = 4 + video_encoding_batch_size: int = 1 + rename_map: dict[str, str] = field(default_factory=dict) + + +@dataclass +class RaCRTCConfig: + robot: RobotConfig = field(default_factory=lambda: OpenArmsFollowerConfig( + port_left="can0", + port_right="can1", + )) + teleop: TeleoperatorConfig = field(default_factory=lambda: OpenArmsMiniConfig( + port_left="/dev/ttyUSB1", + port_right="/dev/ttyUSB0", + )) + dataset: RaCRTCDatasetConfig = field(default_factory=RaCRTCDatasetConfig) + policy: PreTrainedConfig | None = None + + rtc: RTCConfig = field(default_factory=lambda: RTCConfig( + enabled=True, + execution_horizon=20, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.LINEAR, + )) + + interpolation: bool = True + display_data: bool = True + play_sounds: bool = True + resume: bool = False + device: str = "cuda" + action_queue_size_to_get_new_actions: int = 30 + + # Torch compile is disabled by default for real-time inference + # First inference with compile takes minutes to compile kernels + use_torch_compile: bool = False + + def __post_init__(self): + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if self.policy is None: + raise ValueError("policy.path is required") + + @classmethod + def __get_path_fields__(cls) -> list[str]: + return ["policy"] + + +# ============================================================================ +# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py) +# ============================================================================ + +class RobotWrapper: + """Thread-safe wrapper for robot operations.""" + + def __init__(self, robot: Robot): + self.robot = robot + self.lock = Lock() + + def get_observation(self) -> dict[str, Tensor]: + with self.lock: + return self.robot.get_observation() + + def send_action(self, action: dict) -> None: + with self.lock: + self.robot.send_action(action) + + @property + def observation_features(self) -> dict: + return self.robot.observation_features + + @property + def action_features(self) -> dict: + return self.robot.action_features + + @property + def name(self) -> str: + return self.robot.name + + @property + def robot_type(self) -> str: + return self.robot.robot_type + + +# ============================================================================ +# Keyboard/Pedal Listeners +# ============================================================================ + +def init_rac_keyboard_listener(): + """Initialize keyboard listener with RaC-specific controls.""" + events = { + "exit_early": False, + "rerecord_episode": False, + "stop_recording": False, + "policy_paused": False, + "correction_active": False, + "in_reset": False, + "start_next_episode": False, + } + + if is_headless(): + logging.warning("Headless environment - keyboard controls unavailable") + return None, events + + from pynput import keyboard + + def on_press(key): + try: + if events["in_reset"]: + if key == keyboard.Key.space or key == keyboard.Key.right: + print("\n[RaC] Starting next episode...") + events["start_next_episode"] = True + elif hasattr(key, 'char') and key.char == 'c': + print("\n[RaC] Starting next episode...") + events["start_next_episode"] = True + elif key == keyboard.Key.esc: + print("[RaC] ESC - Stop recording, pushing to hub...") + events["stop_recording"] = True + events["start_next_episode"] = True + else: + if key == keyboard.Key.space: + if not events["policy_paused"] and not events["correction_active"]: + print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position") + print(" Press 'c' or START to take control") + events["policy_paused"] = True + elif hasattr(key, 'char') and key.char == 'c': + if events["policy_paused"] and not events["correction_active"]: + print("\n[RaC] ▶ START pressed - taking control") + events["start_next_episode"] = True + elif key == keyboard.Key.right: + print("[RaC] → End episode") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("[RaC] ← Re-record episode") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("[RaC] ESC - Stop recording, pushing to hub...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Key error: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + start_pedal_listener(events) + + return listener, events + + +def start_pedal_listener(events: dict): + """Start foot pedal listener thread if evdev is available.""" + import threading + + try: + from evdev import InputDevice, ecodes # noqa: F401 + except ImportError: + logging.info("[Pedal] evdev not installed - pedal support disabled") + return + + PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + KEY_LEFT = "KEY_A" + KEY_RIGHT = "KEY_C" + + def pedal_reader(): + try: + dev = InputDevice(PEDAL_DEVICE) + print(f"[Pedal] Connected: {dev.name}") + + for ev in dev.read_loop(): + if ev.type != ecodes.EV_KEY: + continue + + from evdev import categorize # noqa: F401 + key = categorize(ev) + code = key.keycode + if isinstance(code, (list, tuple)): + code = code[0] + + if key.keystate != 1: + continue + + if events["in_reset"]: + if code in [KEY_LEFT, KEY_RIGHT]: + events["start_next_episode"] = True + else: + if code == KEY_RIGHT: + if events["correction_active"]: + events["exit_early"] = True + elif not events["policy_paused"]: + events["policy_paused"] = True + elif code == KEY_LEFT: + if events["policy_paused"] and not events["correction_active"]: + events["start_next_episode"] = True + + except FileNotFoundError: + logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") + except PermissionError: + logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}") + except Exception as e: + logging.debug(f"[Pedal] Error: {e}") + + thread = threading.Thread(target=pedal_reader, daemon=True) + thread.start() + + +def make_identity_processors(): + """Create identity processors for RaC recording.""" + teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[IdentityProcessorStep()], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + return teleop_proc, robot_proc, obs_proc + + +# ============================================================================ +# RTC Inference Thread (from evaluate_with_rtc.py) +# ============================================================================ + +def rtc_inference_thread( + policy, + obs_holder: dict, + hw_features: dict, + preprocessor, + postprocessor, + queue_holder: dict, + shutdown_event: Event, + policy_active: Event, + cfg: RaCRTCConfig, +): + """Background thread that generates action chunks using RTC.""" + try: + logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========") + logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys") + + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / cfg.dataset.fps + policy_device = policy.config.device + + get_actions_threshold = cfg.action_queue_size_to_get_new_actions + if not cfg.rtc.enabled: + get_actions_threshold = 0 + + inference_count = 0 + wait_logged = False + + while not shutdown_event.is_set(): + if not policy_active.is_set(): + if not wait_logged: + logger.info("[RTC] Waiting for policy_active...") + wait_logged = True + time.sleep(0.01) + continue + + wait_logged = False + + action_queue = queue_holder["queue"] + if action_queue is None: + logger.warning("[RTC] queue_holder['queue'] is None!") + time.sleep(0.01) + continue + + obs_filtered = obs_holder.get("obs") + if obs_filtered is None: + logger.warning("[RTC] obs_holder['obs'] is None!") + time.sleep(0.01) + continue + + qsize = action_queue.qsize() + if qsize <= get_actions_threshold: + try: + if inference_count == 0: + logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}") + + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() + + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation") + + for name in obs_with_policy_features: + obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255 + obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous() + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device) + + obs_with_policy_features["task"] = [cfg.dataset.single_task] + obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower") + + preprocessed_obs = preprocessor(obs_with_policy_features) + + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + ) + + original_actions = actions.squeeze(0).clone() + postprocessed_actions = postprocessor(actions).squeeze(0) + + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + + action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference) + + inference_count += 1 + logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}") + except Exception as e: + logger.error(f"[RTC] Inference error: {e}") + import traceback + traceback.print_exc() + time.sleep(1.0) + else: + time.sleep(0.01) + + logger.info("[RTC] Inference thread shutting down") + except Exception as e: + logger.error(f"[RTC] THREAD CRASHED: {e}") + import traceback + traceback.print_exc() + + +# ============================================================================ +# Main Rollout Loop +# ============================================================================ + +@safe_stop_image_writer +def rac_rtc_rollout_loop( + robot: RobotWrapper, + teleop: Teleoperator, + policy: PreTrainedPolicy, + preprocessor, + postprocessor, + dataset: LeRobotDataset, + events: dict, + cfg: RaCRTCConfig, + queue_holder: dict, + obs_holder: dict, # Main loop writes obs here for RTC thread to read + policy_active: Event, + hw_features: dict, +) -> dict: + """RaC rollout loop with RTC for smooth policy execution.""" + fps = cfg.dataset.fps + single_task = cfg.dataset.single_task + control_time_s = cfg.dataset.episode_time_s + device = get_safe_torch_device(cfg.device) + + # Reset policy state + policy.reset() + preprocessor.reset() + postprocessor.reset() + + frame_buffer = [] + stats = { + "total_frames": 0, + "autonomous_frames": 0, + "paused_frames": 0, + "correction_frames": 0, + } + + teleop.disable_torque() + was_paused = False + waiting_for_takeover = False + + # Action keys for converting tensor to dict + action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] + + # Interpolation state + prev_action: Tensor | None = None + interpolated_actions: list[Tensor] = [] + interp_idx = 0 + + if cfg.interpolation: + control_interval = 1.0 / (fps * 2) # 2x rate + else: + control_interval = 1.0 / fps + + robot_action = {} + timestamp = 0 + start_t = time.perf_counter() + + while timestamp < control_time_s: + loop_start = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False + break + + # State transition: entering paused state + if events["policy_paused"] and not was_paused: + policy_active.clear() # Stop RTC inference + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} + robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} + print("[RaC] Moving teleop to robot position...") + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + print("[RaC] Teleop aligned. Press 'c' to take control.") + events["start_next_episode"] = False + waiting_for_takeover = True + was_paused = True + # Reset interpolation + prev_action = None + interpolated_actions = [] + interp_idx = 0 + + # Wait for takeover + if waiting_for_takeover and events["start_next_episode"]: + print("[RaC] Taking control...") + teleop.disable_torque() + events["start_next_episode"] = False + events["correction_active"] = True + waiting_for_takeover = False + + # Get observation (ONLY the main loop reads from robot!) + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} + obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) + + # Share observation with RTC thread (thread reads, main loop writes) + obs_holder["obs"] = obs_filtered + + if events["correction_active"]: + # Human controlling + robot_action = teleop.get_action() + for key in robot_action: + if "gripper" in key: + robot_action[key] = -0.65 * robot_action[key] + robot.send_action(robot_action) + stats["correction_frames"] += 1 + + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 + + elif waiting_for_takeover: + stats["paused_frames"] += 1 + + elif events["policy_paused"]: + robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} + teleop.send_feedback(robot_pos) + stats["paused_frames"] += 1 + + else: + # Policy execution with RTC + if not policy_active.is_set(): + policy_active.set() + logger.info("[ROLLOUT] Policy activated, waiting for first actions...") + + action_queue = queue_holder["queue"] + + # Get action from queue (with interpolation) + if interp_idx >= len(interpolated_actions): + new_action = action_queue.get() if action_queue else None + + # Log queue status periodically + if stats["autonomous_frames"] == 0 and new_action is None: + qsize = action_queue.qsize() if action_queue else -1 + if timestamp < 0.5 or int(timestamp * 10) % 10 == 0: + logger.info(f"[ROLLOUT] Waiting for actions... queue_size={qsize}, obs_set={obs_holder.get('obs') is not None}") + + if new_action is not None: + current_action = new_action.cpu() + + if cfg.interpolation and prev_action is not None: + mid = prev_action + 0.5 * (current_action - prev_action) + interpolated_actions = [mid, current_action] + else: + interpolated_actions = [current_action] + + prev_action = current_action + interp_idx = 0 + + if stats["autonomous_frames"] == 0: + logger.info(f"[ROLLOUT] Got first action! Starting robot motion.") + + if interp_idx < len(interpolated_actions): + action_to_send = interpolated_actions[interp_idx] + interp_idx += 1 + + robot_action = {} + for i, key in enumerate(action_keys): + if i < len(action_to_send): + robot_action[key] = action_to_send[i].item() + + robot.send_action(robot_action) + stats["autonomous_frames"] += 1 + + # Record at original fps + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 + + if cfg.display_data: + log_rerun_data(observation=obs_filtered, action=robot_action) + + dt = time.perf_counter() - loop_start + sleep_time = control_interval - dt + if sleep_time > 0: + precise_sleep(sleep_time) + timestamp = time.perf_counter() - start_t + + policy_active.clear() + teleop.disable_torque() + + for frame in frame_buffer: + dataset.add_frame(frame) + + return stats + + +def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int): + """Reset period where human repositions environment.""" + print("\n" + "=" * 65) + print(" [RaC] RESET") + print("=" * 65) + + events["in_reset"] = True + events["start_next_episode"] = False + + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + + print(" Press any key/pedal to enable teleoperation") + while not events["start_next_episode"] and not events["stop_recording"]: + precise_sleep(0.05) + + if events["stop_recording"]: + return + + events["start_next_episode"] = False + teleop.disable_torque() + print(" Teleop enabled - press any key/pedal to start episode") + + while not events["start_next_episode"] and not events["stop_recording"]: + loop_start = time.perf_counter() + action = teleop.get_action() + for key in action: + if "gripper" in key: + action[key] = -0.65 * action[key] + robot.send_action(action) + dt = time.perf_counter() - loop_start + precise_sleep(1 / fps - dt) + + events["in_reset"] = False + events["start_next_episode"] = False + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + +@parser.wrap() +def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: + """Main RaC data collection function with RTC.""" + init_logging() + logging.info(pformat(cfg.__dict__)) + + if cfg.display_data: + init_rerun(session_name="rac_rtc_collection_openarms") + + robot_raw = make_robot_from_config(cfg.robot) + teleop = make_teleoperator_from_config(cfg.teleop) + + teleop_proc, robot_proc, obs_proc = make_identity_processors() + + dataset_features = combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=teleop_proc, + initial_features=create_initial_features(action=robot_raw.action_features), + use_videos=cfg.dataset.video, + ), + aggregate_pipeline_dataset_features( + pipeline=obs_proc, + initial_features=create_initial_features(observation=robot_raw.observation_features), + use_videos=cfg.dataset.video, + ), + ) + + dataset = None + listener = None + shutdown_event = Event() + policy_active = Event() + rtc_thread = None + + try: + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + if hasattr(robot_raw, "cameras") and robot_raw.cameras: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras), + ) + else: + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot_raw.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera + * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + + # Load policy + logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") + policy_class = get_policy_class(cfg.policy.type) + + # Override compile_model for real-time inference (first compile takes minutes) + policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) + if cfg.policy.type in ["pi05", "pi0"]: + policy_config.compile_model = cfg.use_torch_compile + logger.info(f"Set compile_model={cfg.use_torch_compile} for real-time inference") + + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config) + policy.config.rtc_config = cfg.rtc + policy.init_rtc_processor() + policy = policy.to(cfg.device) + policy.eval() + logger.info(f"Policy loaded: {policy.name}") + + # Setup preprocessor/postprocessor + hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation") + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) + + # Connect robot and wrap for thread safety + robot_raw.connect() + robot = RobotWrapper(robot_raw) + + teleop.connect() + listener, events = init_rac_keyboard_listener() + + # Shared state holders (main loop writes, RTC thread reads) + queue_holder = {"queue": ActionQueue(cfg.rtc)} + obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs + + # Start RTC inference thread + # NOTE: Thread does NOT access robot directly - reads from obs_holder + rtc_thread = Thread( + target=rtc_inference_thread, + args=( + policy, + obs_holder, # Thread reads obs from here (set by main loop) + hw_features, + preprocessor, + postprocessor, + queue_holder, + shutdown_event, + policy_active, + cfg, + ), + daemon=True, + name="RTCInference", + ) + rtc_thread.start() + logger.info("Started RTC inference thread") + + print("\n" + "=" * 65) + print(" RaC Data Collection with RTC") + print("=" * 65) + print(f" Policy: {cfg.policy.pretrained_path}") + print(f" Task: {cfg.dataset.single_task}") + print(f" FPS: {cfg.dataset.fps}") + print(f" Interpolation: {cfg.interpolation}") + print() + print(" Controls:") + print(" SPACE - Pause policy") + print(" c - Take control") + print(" → - End episode") + print(" ESC - Stop and push to hub") + print("=" * 65 + "\n") + + with VideoEncodingManager(dataset): + recorded = 0 + while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) + + # Fresh action queue per episode (update holder so thread sees it) + queue_holder["queue"] = ActionQueue(cfg.rtc) + + logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}") + + stats = rac_rtc_rollout_loop( + robot=robot, + teleop=teleop, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset=dataset, + events=events, + cfg=cfg, + queue_holder=queue_holder, + obs_holder=obs_holder, + policy_active=policy_active, + hw_features=hw_features, + ) + + logging.info(f"Episode stats: {stats}") + + if events["rerecord_episode"]: + log_say("Re-recording", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded += 1 + + if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + reset_loop(robot, teleop, events, cfg.dataset.fps) + + finally: + log_say("Stop recording", cfg.play_sounds, blocking=True) + + shutdown_event.set() + policy_active.clear() + + if rtc_thread and rtc_thread.is_alive(): + rtc_thread.join(timeout=2.0) + + if dataset: + dataset.finalize() + + if robot_raw.is_connected: + robot_raw.disconnect() + if teleop.is_connected: + teleop.disconnect() + + if not is_headless() and listener: + listener.stop() + + if cfg.dataset.push_to_hub: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + + return dataset + + +def main(): + from lerobot.utils.import_utils import register_third_party_plugins + register_third_party_plugins() + rac_rtc_collect() + + +if __name__ == "__main__": + main() +