From 5ac8fe32435e158f095d8a771f94a8773c2265f0 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 23 Jan 2026 12:47:54 +0100 Subject: [PATCH] Rename to hil, use two seperta scripts one rtc one synch --- docs/source/_toctree.yml | 2 +- docs/source/hil_collection.mdx | 122 ++- examples/rac/hil_data_collection.py | 349 +++++++ examples/rac/hil_data_collection_rtc.py | 501 ++++++++++ examples/rac/hil_utils.py | 251 +++++ examples/rac/rac_data_collection.py | 639 ------------- .../rac/rac_data_collection_openarms_rtc.py | 889 ------------------ examples/rtc/eval_with_real_robot.py | 17 +- src/lerobot/policies/rtc/__init__.py | 30 + .../policies/rtc/action_interpolator.py | 118 +++ src/lerobot/scripts/lerobot_record.py | 89 +- 11 files changed, 1414 insertions(+), 1593 deletions(-) create mode 100644 examples/rac/hil_data_collection.py create mode 100644 examples/rac/hil_data_collection_rtc.py create mode 100644 examples/rac/hil_utils.py delete mode 100644 examples/rac/rac_data_collection.py delete mode 100644 examples/rac/rac_data_collection_openarms_rtc.py create mode 100644 src/lerobot/policies/rtc/__init__.py create mode 100644 src/lerobot/policies/rtc/action_interpolator.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 71d44c1a2..a56f261ad 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -20,7 +20,7 @@ - local: multi_gpu_training title: Multi GPU training - local: hil_collection - title: Human In the Loop: Recovery and Correction Data Collection + title: Human In the Loop Data Collection title: "Tutorials" - sections: - local: lerobot-dataset-v3 diff --git a/docs/source/hil_collection.mdx b/docs/source/hil_collection.mdx index b77891548..9150c3d8a 100644 --- a/docs/source/hil_collection.mdx +++ b/docs/source/hil_collection.mdx @@ -9,16 +9,18 @@ RaC (Recovery and Correction) is a human-in-the-loop data collection and trainin ### Standard Behavioral Cloning Data Collection Limitations Standard behavior cloning trains policies on successful demonstrations. This approach can be sensitive to distribution shift and compounding errors. Because during deployment small errors can cascade and push the robot into states never seen during training. -This is where RaC whick builds on work like Dagger and HG-DAgger comes in. +This is where RaC which builds on work like Dagger and HG-DAgger comes in. ### Prior Human-in-the-Loop Methods **DAgger** (Dataset Aggregation) addresses distribution shift by: + - Running the novice policy to collect states - Querying expert for correct actions at those states - Aggregating new labels into training set **HG-DAgger** (Human-Gated DAgger) improves on DAgger by: + - Giving human full control authority during interventions - Human takes over when unsafe, provides correction, returns control - Better action labels because human has uninterrupted control @@ -32,15 +34,17 @@ BC/DAgger: policy → mistake → human corrects → continue RaC: policy → mistake → human RECOVERS (teleop back) → CORRECTS → END ``` -THis Human in the loop approach follows two rules +This Human in the loop approach follows two rules: + +**Rule 1 (Recover then Correct)**: -*Rule 1 (Recover then Correct)**: - Every intervention starts with human teleoperating back to an in-distribution state - Then human provides correction to complete the current subtask - Both segments are recorded as training data - This teaches the policy: "when things go wrong, go back and retry" **Rule 2 (Terminate after Intervention)**: + - Episode ends after correction completes - Avoids mixed policy/human data on later subtasks @@ -48,12 +52,39 @@ THis Human in the loop approach follows two rules ## Comparison Table -| Method | Data Type | Recovery Behavior | Correction Behavior | -|--------|-----------|-------------------|---------------------| -| BC | Success only | ✗ | ✗ | -| DAgger | Success + corrections | ✗ | ✓ | -| HG-DAgger | Success + corrections | Sometimes | ✓ | -| RaC | Success + recovery + correction | ✓ Explicit | ✓ | +| Method | Data Type | Recovery Behavior | Correction Behavior | +| --------- | ------------------------------- | ----------------- | ------------------- | +| BC | Success only | ✗ | ✗ | +| DAgger | Success + corrections | ✗ | ✓ | +| HG-DAgger | Success + corrections | Sometimes | ✓ | +| RaC | Success + recovery + correction | ✓ Explicit | ✓ | + +--- + +## Hardware Requirements + +### Teleoperator Requirements + +The HIL data collection script requires **teleoperators with active motors** that can: + +- Enable/disable torque programmatically +- Move to target positions (to mirror robot state when pausing) + +**Compatible teleoperators:** + +- `so101_leader` - SO-101 Leader Arm +- `openarms_mini` - OpenArms Mini (via third-party plugin) + +--- + +## Scripts + +Two scripts are provided depending on your policy's inference speed: + +| Script | Use Case | Models | +| ---------------------------- | ------------------------------------------ | --------------------- | +| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy | +| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA | --- @@ -67,7 +98,7 @@ THis Human in the loop approach follows two rules │ 1. PRE-TRAINING (Standard BC) │ │ └─> Train initial policy on clean demonstrations │ │ │ -│ 2. RAC DATA COLLECTION (Human-in-the-loop) │ +│ 2. HIL DATA COLLECTION (Human-in-the-loop) │ │ ├─> Policy runs autonomously │ │ ├─> Human monitors and intervenes when failure imminent │ │ │ ├─> RECOVERY: Human teleoperates robot back to good state │ @@ -78,7 +109,7 @@ THis Human in the loop approach follows two rules │ └─> Compute progress rewards for advantage-weighted training │ │ │ │ 4. FINE-TUNING │ -│ └─> Train on combined demos + RaC data (optionally with RA-BC) │ +│ └─> Train on combined demos + HIL data (optionally with RA-BC) │ │ │ └─────────────────────────────────────────────────────────────────────────┘ ``` @@ -100,35 +131,50 @@ python src/lerobot/scripts/lerobot_train.py \ --steps=50000 ``` -### Step 2: Collect RaC Data +### Step 2: Collect HIL Data -Run the RaC data collection script with your pre-trained policy: +**Standard inference (ACT, Diffusion Policy):** ```bash -python examples/rac/rac_data_collection.py \ +python examples/rac/hil_data_collection.py \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ - --dataset.repo_id=your-username/rac-dataset \ + --dataset.repo_id=your-username/hil-dataset \ --dataset.single_task="Pick up the cube and place it in the bowl" \ --dataset.num_episodes=50 ``` +**With RTC for large models (Pi0, Pi0.5, SmolVLA):** + +For models with high inference latency, use the RTC script for smooth execution: + +```bash +python examples/rac/hil_data_collection_rtc.py \ + --robot.type=so100_follower \ + --teleop.type=so100_leader \ + --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ + --dataset.repo_id=your-username/hil-rtc-dataset \ + --dataset.single_task="Pick up the cube" \ + --rtc.execution_horizon=20 \ + --interpolation=true +``` + **Controls (Keyboard + Foot Pedal):** -| Key / Pedal | Action | -|-------------|--------| -| **SPACE** / Right pedal | Pause policy (teleop mirrors robot, no recording) | -| **c** / Left pedal | Take control (start correction, recording resumes) | -| **→** / Right pedal | End episode (save) - when in correction mode | -| **←** | Re-record episode | -| **ESC** | Stop session and push to hub | -| Any key/pedal during reset | Start next episode | +| Key / Pedal | Action | +| -------------------------- | -------------------------------------------------- | +| **SPACE** / Right pedal | Pause policy (teleop mirrors robot, no recording) | +| **c** / Left pedal | Take control (start correction, recording resumes) | +| **→** / Right pedal | End episode (save) - when in correction mode | +| **←** | Re-record episode | +| **ESC** | Stop session and push to hub | +| Any key/pedal during reset | Start next episode | -**The RaC Protocol:** +**The HIL Protocol:** 1. Watch the policy run autonomously (teleop is idle/free) 2. When you see imminent failure, press **SPACE** or **right pedal** to pause @@ -149,6 +195,7 @@ The recovery and correction segments teach the policy how to recover from errors **Foot Pedal Setup (Linux):** If using a USB foot pedal (PCsensor FootSwitch), ensure access: + ```bash sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd ``` @@ -159,7 +206,7 @@ For advantage-weighted training (RA-BC / Pi0.6-style), compute SARM progress val ```bash python src/lerobot/policies/sarm/compute_rabc_weights.py \ - --dataset-repo-id your-username/rac-dataset \ + --dataset-repo-id your-username/hil-dataset \ --reward-model-path your-username/sarm-model \ --head-mode sparse \ --push-to-hub @@ -167,23 +214,23 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \ ### Step 4: Fine-tune Policy -Fine-tune on the RaC data: +Fine-tune on the HIL data: ```bash # Without RA-BC (standard fine-tuning) python src/lerobot/scripts/lerobot_train.py \ - --dataset.repo_id=your-username/rac-dataset \ + --dataset.repo_id=your-username/hil-dataset \ --policy.type=pi0 \ --policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \ - --output_dir=outputs/rac_finetune \ + --output_dir=outputs/hil_finetune \ --steps=20000 # With RA-BC (advantage-weighted, Pi0.6-style) python src/lerobot/scripts/lerobot_train.py \ - --dataset.repo_id=your-username/rac-dataset \ + --dataset.repo_id=your-username/hil-dataset \ --policy.type=pi0 \ --policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \ - --output_dir=outputs/rac_finetune_rabc \ + --output_dir=outputs/hil_finetune_rabc \ --use_rabc=true \ --rabc_kappa=0.01 \ --steps=20000 @@ -194,22 +241,25 @@ python src/lerobot/scripts/lerobot_train.py \ ## Connection to Pi0.6 / RECAP Pi0.6's RECAP method shares similar principles: + - Collect autonomous rollouts + expert interventions - Use value function to compute **advantages**: A(s,a) = V(s') - V(s) - **Advantage conditioning**: Weight training based on expected improvement In LeRobot, we can use **SARM** as the value function: + - SARM progress φ(s) ∈ [0,1] measures task completion - Progress delta = φ(s') - φ(s) approximates advantage - RA-BC uses these to weight training samples (higher weight for good corrections) --- -## Tips for Effective RaC Collection +## Tips for Effective HIL Collection ### When to Intervene Intervene when you see: + - Robot about to make an irreversible mistake - Robot hesitating or showing uncertain behavior - Robot deviating from expected trajectory @@ -217,6 +267,7 @@ Intervene when you see: ### Recovery: Teleoperating Back to Good State During recovery, teleoperate the robot back to a state where: + - The robot is in a familiar, in-distribution configuration - The current subtask can still be completed - The recovery trajectory itself is informative training data @@ -224,6 +275,7 @@ During recovery, teleoperate the robot back to a state where: ### Quality of Corrections During correction: + - Provide **confident, clean** trajectories - Complete the current subtask fully - Don't overcorrect or add unnecessary movements @@ -232,15 +284,15 @@ During correction: ## Iterative Improvement -RaC can be applied iteratively: +HIL data collection can be applied iteratively: ``` ┌─────────────────────────────────────────────────────────────────────────┐ │ Policy v0 (demos) │ │ ↓ │ -│ RaC Collection (target current failure modes) → Policy v1 │ +│ HIL Collection (target current failure modes) → Policy v1 │ │ ↓ │ -│ RaC Collection (target new failure modes) → Policy v2 │ +│ HIL Collection (target new failure modes) → Policy v2 │ │ ↓ │ │ ... (repeat until satisfactory performance) │ └─────────────────────────────────────────────────────────────────────────┘ @@ -278,5 +330,3 @@ RaC can be applied iteratively: year={2025} } ``` - - diff --git a/examples/rac/hil_data_collection.py b/examples/rac/hil_data_collection.py new file mode 100644 index 000000000..4555eafc4 --- /dev/null +++ b/examples/rac/hil_data_collection.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python +""" +Human-in-the-Loop (HIL) Data Collection with Policy Rollout. + +Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous +inference. For large models with high inference latency, use hil_data_collection_rtc.py. + +The workflow: +1. Policy runs autonomously +2. Press SPACE to pause - robot holds position +3. Press 'c' to take control - human provides RECOVERY + CORRECTION +4. Press → to end episode (save and continue to next) +5. Reset, then do next rollout + +Keyboard Controls: + SPACE - Pause policy (robot holds position, no recording) + c - Take control (start correction, recording resumes) + → - End episode (save and continue to next) + ← - Re-record episode + ESC - Stop recording and push dataset to hub + +Usage: + python examples/rac/hil_data_collection.py \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ + --dataset.repo_id=my_user/hil_dataset \ + --dataset.single_task="Pick up the cube" +""" + +import logging +import time +from dataclasses import dataclass +from pprint import pformat +from typing import Any + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig +from lerobot.datasets.image_writer import safe_stop_image_writer +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts +from lerobot.datasets.video_utils import VideoEncodingManager +import torch + +from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc import ActionInterpolator +from lerobot.policies.utils import make_robot_action +from lerobot.processor import PolicyProcessorPipeline +from lerobot.processor.rename_processor import rename_stats +from lerobot.robots import Robot, RobotConfig, make_robot_from_config +from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.control_utils import is_headless, predict_action +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data + +from hil_utils import ( + HILDatasetConfig, + init_keyboard_listener, + make_identity_processors, + print_controls, + reset_loop, + teleop_disable_torque, + teleop_smooth_move_to, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class HILConfig: + robot: RobotConfig + teleop: TeleoperatorConfig + dataset: HILDatasetConfig + policy: PreTrainedConfig | None = None + interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x) + display_data: bool = True + play_sounds: bool = True + resume: bool = False + device: str = "cuda" + + def __post_init__(self): + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if self.policy is None: + raise ValueError("policy.path is required") + + @classmethod + def __get_path_fields__(cls) -> list[str]: + return ["policy"] + + +@safe_stop_image_writer +def rollout_loop( + robot: Robot, + teleop: Teleoperator, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + dataset: LeRobotDataset, + events: dict, + cfg: HILConfig, +): + """Rollout loop with standard synchronous inference.""" + fps = cfg.dataset.fps + device = get_safe_torch_device(cfg.device) + + policy.reset() + preprocessor.reset() + postprocessor.reset() + + frame_buffer = [] + teleop_disable_torque(teleop) + + was_paused = False + waiting_for_takeover = False + last_action: dict[str, Any] | None = None + robot_action: dict[str, Any] = {} + action_keys = sorted([k for k in robot.action_features.keys()]) + + interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) + control_interval = interpolator.get_control_interval(fps) + + timestamp = 0 + start_t = time.perf_counter() + + while timestamp < cfg.dataset.episode_time_s: + loop_start = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False + break + + # Transition to paused state + if events["policy_paused"] and not was_paused: + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) + events["start_next_episode"] = False + waiting_for_takeover = True + was_paused = True + interpolator.reset() + + # Takeover + if waiting_for_takeover and events["start_next_episode"]: + teleop_disable_torque(teleop) + events["start_next_episode"] = False + events["correction_active"] = True + waiting_for_takeover = False + + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} + obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) + + if events["correction_active"]: + robot_action = teleop.get_action() + robot.send_action(robot_action) + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task}) + + elif waiting_for_takeover or events["policy_paused"]: + if last_action: + robot.send_action(last_action) + + else: + # Policy execution with optional interpolation + if interpolator.needs_new_action(): + action_values = predict_action( + observation=obs_frame, + policy=policy, + device=device, + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, + task=cfg.dataset.single_task, + robot_type=robot.robot_type, + ) + robot_action = make_robot_action(action_values, dataset.features) + action_tensor = torch.tensor([robot_action[k] for k in action_keys]) + interpolator.add(action_tensor) + + interp_action = interpolator.get() + if interp_action is not None: + robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)} + robot.send_action(robot_action) + last_action = robot_action + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task}) + + if cfg.display_data and robot_action: + log_rerun_data(observation=obs_filtered, action=robot_action) + + dt = time.perf_counter() - loop_start + if (sleep_time := control_interval - dt) > 0: + precise_sleep(sleep_time) + timestamp = time.perf_counter() - start_t + + teleop_disable_torque(teleop) + + for frame in frame_buffer: + dataset.add_frame(frame) + + +@parser.wrap() +def hil_collect(cfg: HILConfig) -> LeRobotDataset: + """Main HIL data collection function.""" + init_logging() + logger.info(pformat(cfg.__dict__)) + + if cfg.display_data: + init_rerun(session_name="hil_collection") + + robot = make_robot_from_config(cfg.robot) + teleop = make_teleoperator_from_config(cfg.teleop) + + teleop_proc, obs_proc = make_identity_processors() + + dataset_features = combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=teleop_proc, + initial_features=create_initial_features(action=robot.action_features), + use_videos=cfg.dataset.video, + ), + aggregate_pipeline_dataset_features( + pipeline=obs_proc, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=cfg.dataset.video, + ), + ) + + dataset = None + listener = None + + try: + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + if hasattr(robot, "cameras") and robot.cameras: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + else: + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera + * len(robot.cameras if hasattr(robot, "cameras") else []), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + + policy = make_policy(cfg.policy, ds_meta=dataset.meta) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) + + robot.connect() + teleop.connect() + listener, events = init_keyboard_listener() + + print_controls(rtc=False) + print(f" Policy: {cfg.policy.pretrained_path}") + print(f" Task: {cfg.dataset.single_task}") + print(f" Interpolation: {cfg.interpolation_multiplier}x\n") + + with VideoEncodingManager(dataset): + recorded = 0 + while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds) + + rollout_loop( + robot=robot, + teleop=teleop, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset=dataset, + events=events, + cfg=cfg, + ) + + if events["rerecord_episode"]: + log_say("Re-recording", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded += 1 + + if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + reset_loop(robot, teleop, events, cfg.dataset.fps) + + finally: + log_say("Stop recording", cfg.play_sounds, blocking=True) + + if dataset: + dataset.finalize() + + if robot.is_connected: + robot.disconnect() + if teleop.is_connected: + teleop.disconnect() + + if not is_headless() and listener: + listener.stop() + + if cfg.dataset.push_to_hub: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + + return dataset + + +def main(): + from lerobot.utils.import_utils import register_third_party_plugins + register_third_party_plugins() + hil_collect() + + +if __name__ == "__main__": + main() diff --git a/examples/rac/hil_data_collection_rtc.py b/examples/rac/hil_data_collection_rtc.py new file mode 100644 index 000000000..85390a35c --- /dev/null +++ b/examples/rac/hil_data_collection_rtc.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python +""" +Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC). + +Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models +(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks +asynchronously in a background thread for smooth robot control. + +For fast models (ACT, Diffusion), use hil_data_collection.py instead. + +The workflow: +1. Policy runs autonomously with RTC +2. Press SPACE to pause - robot holds position +3. Press 'c' to take control - human provides RECOVERY + CORRECTION +4. Press → to end episode (save and continue to next) +5. Reset, then do next rollout + +Keyboard Controls: + SPACE - Pause policy (robot holds position, no recording) + c - Take control (start correction, recording resumes) + → - End episode (save and continue to next) + ← - Re-record episode + ESC - Stop recording and push dataset to hub + +Usage: + python examples/rac/hil_data_collection_rtc.py \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \ + --dataset.repo_id=my_user/hil_rtc_dataset \ + --dataset.single_task="Pick up the cube" \ + --rtc.execution_horizon=20 +""" + +import logging +import math +import time +from dataclasses import dataclass, field +from pprint import pformat +from threading import Event, Lock, Thread +from typing import Any + +import torch + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig +from lerobot.datasets.image_writer import safe_stop_image_writer +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features +from lerobot.datasets.video_utils import VideoEncodingManager +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig +from lerobot.processor import PolicyProcessorPipeline +from lerobot.processor.rename_processor import rename_stats +from lerobot.robots import Robot, RobotConfig, make_robot_from_config +from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config +from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.control_utils import is_headless +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data + +from hil_utils import ( + HILDatasetConfig, + init_keyboard_listener, + make_identity_processors, + print_controls, + reset_loop, + teleop_disable_torque, + teleop_smooth_move_to, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class HILRTCConfig: + robot: RobotConfig + teleop: TeleoperatorConfig + dataset: HILDatasetConfig + policy: PreTrainedConfig | None = None + rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20)) + interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x) + display_data: bool = True + play_sounds: bool = True + resume: bool = False + device: str = "cuda" + use_torch_compile: bool = False # First compile takes minutes, disable for real-time + + def __post_init__(self): + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if self.policy is None: + raise ValueError("policy.path is required") + self.rtc.enabled = True + + @classmethod + def __get_path_fields__(cls) -> list[str]: + return ["policy"] + + +class ThreadSafeRobot: + """Thread-safe wrapper for robot operations.""" + + def __init__(self, robot: Robot): + self._robot = robot + self._lock = Lock() + + def get_observation(self) -> dict[str, Any]: + with self._lock: + return self._robot.get_observation() + + def send_action(self, action: dict) -> None: + with self._lock: + self._robot.send_action(action) + + @property + def observation_features(self) -> dict: + return self._robot.observation_features + + @property + def action_features(self) -> dict: + return self._robot.action_features + + @property + def name(self) -> str: + return self._robot.name + + @property + def robot_type(self) -> str: + return self._robot.robot_type + + @property + def cameras(self): + return getattr(self._robot, "cameras", {}) + + +def rtc_inference_thread( + policy: PreTrainedPolicy, + obs_holder: dict, + hw_features: dict, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + queue_holder: dict, + shutdown_event: Event, + policy_active: Event, + cfg: HILRTCConfig, +): + """Background thread for RTC action chunk generation.""" + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / cfg.dataset.fps + threshold = 30 + + while not shutdown_event.is_set(): + if not policy_active.is_set(): + time.sleep(0.01) + continue + + queue = queue_holder.get("queue") + obs = obs_holder.get("obs") + if queue is None or obs is None: + time.sleep(0.01) + continue + + if queue.qsize() <= threshold: + try: + current_time = time.perf_counter() + idx_before = queue.get_action_index() + prev_actions = queue.get_left_over() + + latency = latency_tracker.max() + delay = math.ceil(latency / time_per_chunk) if latency else 0 + + obs_batch = build_dataset_frame(hw_features, obs, prefix="observation") + for name in obs_batch: + obs_batch[name] = torch.from_numpy(obs_batch[name]) + if "image" in name: + obs_batch[name] = obs_batch[name].float() / 255 + obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous() + obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device) + + obs_batch["task"] = [cfg.dataset.single_task] + obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown") + + preprocessed = preprocessor(obs_batch) + actions = policy.predict_action_chunk( + preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions + ) + + original = actions.squeeze(0).clone() + processed = postprocessor(actions).squeeze(0) + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + queue.merge(original, processed, new_delay, idx_before) + logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}") + except Exception as e: + logger.error(f"[RTC] Error: {e}") + time.sleep(0.5) + else: + time.sleep(0.01) + + +@safe_stop_image_writer +def rollout_loop( + robot: ThreadSafeRobot, + teleop: Teleoperator, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + dataset: LeRobotDataset, + events: dict, + cfg: HILRTCConfig, + queue_holder: dict, + obs_holder: dict, + policy_active: Event, + hw_features: dict, +): + """Rollout loop with RTC for asynchronous inference.""" + fps = cfg.dataset.fps + + policy.reset() + preprocessor.reset() + postprocessor.reset() + + frame_buffer = [] + teleop_disable_torque(teleop) + + was_paused = False + waiting_for_takeover = False + last_action: dict[str, Any] | None = None + action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] + + interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) + control_interval = interpolator.get_control_interval(fps) + + robot_action: dict[str, Any] = {} + timestamp = 0 + start_t = time.perf_counter() + + while timestamp < cfg.dataset.episode_time_s: + loop_start = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False + break + + # Transition to paused state + if events["policy_paused"] and not was_paused: + policy_active.clear() + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) + events["start_next_episode"] = False + waiting_for_takeover = True + was_paused = True + interpolator.reset() + + # Takeover + if waiting_for_takeover and events["start_next_episode"]: + teleop_disable_torque(teleop) + events["start_next_episode"] = False + events["correction_active"] = True + waiting_for_takeover = False + + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} + obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) + + obs_holder["obs"] = obs_filtered + + if events["correction_active"]: + robot_action = teleop.get_action() + robot.send_action(robot_action) + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task}) + + elif waiting_for_takeover or events["policy_paused"]: + if last_action: + robot.send_action(last_action) + + else: + # Policy execution with RTC + if not policy_active.is_set(): + policy_active.set() + + queue = queue_holder["queue"] + + if interpolator.needs_new_action(): + new_action = queue.get() if queue else None + if new_action is not None: + interpolator.add(new_action.cpu()) + + action_tensor = interpolator.get() + if action_tensor is not None: + robot_action = {k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)} + robot.send_action(robot_action) + last_action = robot_action + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task}) + + if cfg.display_data and robot_action: + log_rerun_data(observation=obs_filtered, action=robot_action) + + dt = time.perf_counter() - loop_start + if (sleep_time := control_interval - dt) > 0: + precise_sleep(sleep_time) + timestamp = time.perf_counter() - start_t + + policy_active.clear() + teleop_disable_torque(teleop) + + for frame in frame_buffer: + dataset.add_frame(frame) + + +@parser.wrap() +def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset: + """Main HIL data collection function with RTC.""" + init_logging() + logger.info(pformat(cfg.__dict__)) + + if cfg.display_data: + init_rerun(session_name="hil_rtc_collection") + + robot_raw = make_robot_from_config(cfg.robot) + teleop = make_teleoperator_from_config(cfg.teleop) + + teleop_proc, obs_proc = make_identity_processors() + + dataset_features = combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=teleop_proc, + initial_features=create_initial_features(action=robot_raw.action_features), + use_videos=cfg.dataset.video, + ), + aggregate_pipeline_dataset_features( + pipeline=obs_proc, + initial_features=create_initial_features(observation=robot_raw.observation_features), + use_videos=cfg.dataset.video, + ), + ) + + dataset = None + listener = None + shutdown_event = Event() + policy_active = Event() + rtc_thread = None + + try: + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + if hasattr(robot_raw, "cameras") and robot_raw.cameras: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras), + ) + else: + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot_raw.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera + * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + + # Load policy with RTC + policy_class = get_policy_class(cfg.policy.type) + policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) + if hasattr(policy_config, "compile_model"): + policy_config.compile_model = cfg.use_torch_compile + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config) + policy.config.rtc_config = cfg.rtc + if hasattr(policy, "init_rtc_processor"): + policy.init_rtc_processor() + policy = policy.to(cfg.device) + policy.eval() + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) + + robot_raw.connect() + robot = ThreadSafeRobot(robot_raw) + teleop.connect() + listener, events = init_keyboard_listener() + + queue_holder = {"queue": ActionQueue(cfg.rtc)} + obs_holder = {"obs": None, "robot_type": robot.robot_type} + hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation") + + rtc_thread = Thread( + target=rtc_inference_thread, + args=(policy, obs_holder, hw_features, preprocessor, postprocessor, + queue_holder, shutdown_event, policy_active, cfg), + daemon=True, + ) + rtc_thread.start() + + print_controls(rtc=True) + print(f" Policy: {cfg.policy.pretrained_path}") + print(f" Task: {cfg.dataset.single_task}") + print(f" Interpolation: {cfg.interpolation_multiplier}x\n") + + with VideoEncodingManager(dataset): + recorded = 0 + while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds) + + queue_holder["queue"] = ActionQueue(cfg.rtc) + + rollout_loop( + robot=robot, + teleop=teleop, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset=dataset, + events=events, + cfg=cfg, + queue_holder=queue_holder, + obs_holder=obs_holder, + policy_active=policy_active, + hw_features=hw_features, + ) + + if events["rerecord_episode"]: + log_say("Re-recording", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded += 1 + + if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + reset_loop(robot, teleop, events, cfg.dataset.fps) + + finally: + log_say("Stop recording", cfg.play_sounds, blocking=True) + + shutdown_event.set() + policy_active.clear() + + if rtc_thread and rtc_thread.is_alive(): + rtc_thread.join(timeout=2.0) + + if dataset: + dataset.finalize() + + if robot_raw.is_connected: + robot_raw.disconnect() + if teleop.is_connected: + teleop.disconnect() + + if not is_headless() and listener: + listener.stop() + + if cfg.dataset.push_to_hub: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + + return dataset + + +def main(): + from lerobot.utils.import_utils import register_third_party_plugins + register_third_party_plugins() + hil_rtc_collect() + + +if __name__ == "__main__": + main() + diff --git a/examples/rac/hil_utils.py b/examples/rac/hil_utils.py new file mode 100644 index 000000000..b1b65f7ef --- /dev/null +++ b/examples/rac/hil_utils.py @@ -0,0 +1,251 @@ +"""Shared utilities for Human-in-the-Loop data collection scripts.""" + +import logging +import time +from dataclasses import dataclass, field +from pathlib import Path + +from lerobot.processor import ( + IdentityProcessorStep, + RobotAction, + RobotObservation, + RobotProcessorPipeline, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.robots import Robot, RobotConfig +from lerobot.teleoperators import Teleoperator, TeleoperatorConfig +from lerobot.utils.control_utils import is_headless +from lerobot.utils.robot_utils import precise_sleep + +logger = logging.getLogger(__name__) + + +@dataclass +class HILDatasetConfig: + repo_id: str + single_task: str + root: str | Path | None = None + fps: int = 30 + episode_time_s: float = 120 + num_episodes: int = 50 + video: bool = True + push_to_hub: bool = True + private: bool = False + tags: list[str] | None = None + num_image_writer_processes: int = 0 + num_image_writer_threads_per_camera: int = 4 + video_encoding_batch_size: int = 1 + rename_map: dict[str, str] = field(default_factory=dict) + + +def teleop_has_motor_control(teleop: Teleoperator) -> bool: + """Check if teleoperator has motor control capabilities.""" + return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque") + + +def teleop_disable_torque(teleop: Teleoperator) -> None: + """Disable teleop torque if supported.""" + if teleop_has_motor_control(teleop): + teleop.bus.disable_torque() + + +def teleop_enable_torque(teleop: Teleoperator) -> None: + """Enable teleop torque if supported.""" + if teleop_has_motor_control(teleop): + teleop.bus.enable_torque() + + +def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50): + """Smoothly move teleop to target position if motor control is available.""" + if not teleop_has_motor_control(teleop): + logger.warning("Teleop does not support motor control - cannot mirror robot position") + return + + teleop_enable_torque(teleop) + current = teleop.get_action() + steps = int(duration_s * fps) + + for step in range(steps + 1): + t = step / steps + interp = {} + for k in current: + if k in target_pos: + interp[k] = current[k] * (1 - t) + target_pos[k] * t + else: + interp[k] = current[k] + teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()}) + time.sleep(1 / fps) + + +def init_keyboard_listener(): + """Initialize keyboard listener with HIL controls.""" + events = { + "exit_early": False, + "rerecord_episode": False, + "stop_recording": False, + "policy_paused": False, + "correction_active": False, + "in_reset": False, + "start_next_episode": False, + } + + if is_headless(): + logger.warning("Headless environment - keyboard controls unavailable") + return None, events + + from pynput import keyboard + + def on_press(key): + try: + if events["in_reset"]: + if key in [keyboard.Key.space, keyboard.Key.right]: + print("\n[HIL] Starting next episode...") + events["start_next_episode"] = True + elif hasattr(key, "char") and key.char == "c": + events["start_next_episode"] = True + elif key == keyboard.Key.esc: + print("[HIL] ESC - Stop recording, pushing to hub...") + events["stop_recording"] = True + events["start_next_episode"] = True + else: + if key == keyboard.Key.space: + if not events["policy_paused"] and not events["correction_active"]: + print("\n[HIL] ⏸ PAUSED - Press 'c' to take control") + events["policy_paused"] = True + elif hasattr(key, "char") and key.char == "c": + if events["policy_paused"] and not events["correction_active"]: + print("\n[HIL] ▶ Taking control...") + events["start_next_episode"] = True + elif key == keyboard.Key.right: + print("[HIL] → End episode") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("[HIL] ← Re-record episode") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("[HIL] ESC - Stop recording...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Key error: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + start_pedal_listener(events) + return listener, events + + +def start_pedal_listener(events: dict): + """Start foot pedal listener if evdev is available.""" + import threading + + try: + from evdev import InputDevice, categorize, ecodes + except ImportError: + logger.info("[Pedal] evdev not installed - pedal support disabled") + return + + PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + KEY_LEFT, KEY_RIGHT = "KEY_A", "KEY_C" + + def pedal_reader(): + try: + dev = InputDevice(PEDAL_DEVICE) + print(f"[Pedal] Connected: {dev.name}") + for ev in dev.read_loop(): + if ev.type != ecodes.EV_KEY: + continue + key = categorize(ev) + code = key.keycode[0] if isinstance(key.keycode, (list, tuple)) else key.keycode + if key.keystate != 1: + continue + + if events["in_reset"]: + if code in [KEY_LEFT, KEY_RIGHT]: + events["start_next_episode"] = True + else: + if code == KEY_RIGHT: + if events["correction_active"]: + events["exit_early"] = True + elif not events["policy_paused"]: + events["policy_paused"] = True + elif code == KEY_LEFT: + if events["policy_paused"] and not events["correction_active"]: + events["start_next_episode"] = True + except (FileNotFoundError, PermissionError) as e: + logger.info(f"[Pedal] {e}") + + threading.Thread(target=pedal_reader, daemon=True).start() + + +def make_identity_processors(): + """Create identity processors for recording.""" + teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[IdentityProcessorStep()], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + return teleop_proc, obs_proc + + +def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int): + """Reset period where human repositions environment.""" + print("\n" + "=" * 60) + print(" [HIL] RESET") + print("=" * 60) + + events["in_reset"] = True + events["start_next_episode"] = False + + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) + + print(" Press any key/pedal to enable teleoperation") + while not events["start_next_episode"] and not events["stop_recording"]: + precise_sleep(0.05) + + if events["stop_recording"]: + return + + events["start_next_episode"] = False + teleop_disable_torque(teleop) + print(" Teleop enabled - press any key/pedal to start episode") + + while not events["start_next_episode"] and not events["stop_recording"]: + loop_start = time.perf_counter() + action = teleop.get_action() + robot.send_action(action) + precise_sleep(1 / fps - (time.perf_counter() - loop_start)) + + events["in_reset"] = False + events["start_next_episode"] = False + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False + + +def print_controls(rtc: bool = False): + """Print control instructions.""" + print("\n" + "=" * 60) + print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")) + print("=" * 60) + print() + print(" Controls:") + print(" SPACE - Pause policy") + print(" c - Take control") + print(" → - End episode") + print(" ESC - Stop and push to hub") + print("=" * 60 + "\n") + diff --git a/examples/rac/rac_data_collection.py b/examples/rac/rac_data_collection.py deleted file mode 100644 index 62863b886..000000000 --- a/examples/rac/rac_data_collection.py +++ /dev/null @@ -1,639 +0,0 @@ -#!/usr/bin/env python -""" -RaC (Recovery and Correction) Data Collection with Policy Rollout + Human Intervention. - -This implements the RaC paradigm from "RaC: Robot Learning for Long-Horizon Tasks -by Scaling Recovery and Correction" (Hu et al., 2025) for LeRobot. - -RaC improves upon standard data collection (BC) and prior human-in-the-loop methods -(DAgger, HG-DAgger) by explicitly collecting recovery and correction behaviors: - -The workflow: -1. Policy runs autonomously -2. Press SPACE to pause - robot holds position -3. Press 'c' to take control - human provides RECOVERY + CORRECTION -4. Press → to end episode (save and continue to next) -5. Reset, then do next rollout - -Key RaC Rules: -- Rule 1 (Recover then Correct): Every intervention = recovery + correction (both human) -- Rule 2 (Terminate after Intervention): Episode ends after correction - -The recovery segment (teleoperating back to good state) is recorded as training data - -this teaches the policy how to recover from errors. - -Keyboard Controls: - SPACE - Pause policy (robot holds position, no recording) - c - Take control (start correction, recording resumes) - → - End episode (save and continue to next) - ← - Re-record episode - ESC - Stop recording and push dataset to hub - -Usage: - python examples/rac/rac_data_collection.py \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ - --teleop.type=so100_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ - --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ - --dataset.repo_id=my_user/rac_dataset \ - --dataset.single_task="Pick up the cube" -""" - -import logging -import time -from dataclasses import dataclass, field -from pathlib import Path -from pprint import pformat -from typing import Any - -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig -from lerobot.datasets.image_writer import safe_stop_image_writer -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts -from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import make_robot_action -from lerobot.processor import ( - IdentityProcessor, - PolicyAction, - PolicyProcessorPipeline, - RobotAction, - RobotObservation, - RobotProcessorPipeline, -) -from lerobot.processor.converters import ( - observation_to_transition, - robot_action_observation_to_transition, - transition_to_observation, - transition_to_robot_action, -) -from lerobot.processor.rename_processor import rename_stats -from lerobot.robots import Robot, RobotConfig, make_robot_from_config -from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config -from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.control_utils import is_headless, predict_action -from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say -from lerobot.utils.visualization_utils import init_rerun, log_rerun_data - - -@dataclass -class RaCDatasetConfig: - repo_id: str - single_task: str - root: str | Path | None = None - fps: int = 30 - episode_time_s: float = 120 - reset_time_s: float = 30 - num_episodes: int = 50 - video: bool = True - push_to_hub: bool = True - private: bool = False - tags: list[str] | None = None - num_image_writer_processes: int = 0 - num_image_writer_threads_per_camera: int = 4 - video_encoding_batch_size: int = 1 - rename_map: dict[str, str] = field(default_factory=dict) - - -@dataclass -class RaCConfig: - robot: RobotConfig - dataset: RaCDatasetConfig - policy: PreTrainedConfig - teleop: TeleoperatorConfig - display_data: bool = True - play_sounds: bool = True - resume: bool = False - - def __post_init__(self): - policy_path = parser.get_path_arg("policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - - @classmethod - def __get_path_fields__(cls) -> list[str]: - return ["policy"] - - -def init_rac_keyboard_listener(): - """Initialize keyboard listener with RaC-specific controls.""" - events = { - "exit_early": False, - "rerecord_episode": False, - "stop_recording": False, - "policy_paused": False, # SPACE pressed - policy paused, teleop tracking robot - "correction_active": False, # 'c' pressed - human controlling, recording correction - "in_reset": False, # True during reset period - "start_next_episode": False, # Signal to start next episode - } - - if is_headless(): - logging.warning("Headless environment - keyboard controls unavailable") - return None, events - - from pynput import keyboard - - def on_press(key): - try: - if events["in_reset"]: - # During reset: any action key starts next episode - if key == keyboard.Key.space or key == keyboard.Key.right: - print("\n[RaC] Starting next episode...") - events["start_next_episode"] = True - elif hasattr(key, 'char') and key.char == 'c': - print("\n[RaC] Starting next episode...") - events["start_next_episode"] = True - elif key == keyboard.Key.esc: - print("[RaC] ESC - Stop recording, pushing to hub...") - events["stop_recording"] = True - events["start_next_episode"] = True - else: - # During episode - if key == keyboard.Key.space: - if not events["policy_paused"] and not events["correction_active"]: - print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position") - print(" Press 'c' or START to take control") - events["policy_paused"] = True - elif hasattr(key, 'char') and key.char == 'c': - if events["policy_paused"] and not events["correction_active"]: - print("\n[RaC] ▶ START pressed - taking control") - events["start_next_episode"] = True - elif key == keyboard.Key.right: - print("[RaC] → End episode") - events["exit_early"] = True - elif key == keyboard.Key.left: - print("[RaC] ← Re-record episode") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("[RaC] ESC - Stop recording, pushing to hub...") - events["stop_recording"] = True - events["exit_early"] = True - except Exception as e: - print(f"Key error: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - start_pedal_listener(events) - - return listener, events - - -def start_pedal_listener(events: dict): - """Start foot pedal listener thread if evdev is available.""" - import threading - - try: - from evdev import InputDevice, ecodes - except ImportError: - logging.info("[Pedal] evdev not installed - pedal support disabled") - return - - PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" - KEY_LEFT = "KEY_A" # Left pedal - KEY_RIGHT = "KEY_C" # Right pedal - - def pedal_reader(): - try: - dev = InputDevice(PEDAL_DEVICE) - print(f"[Pedal] Connected: {dev.name}") - print(f"[Pedal] Right=pause/next, Left=take control/start") - - for ev in dev.read_loop(): - if ev.type != ecodes.EV_KEY: - continue - - from evdev import categorize - key = categorize(ev) - code = key.keycode - if isinstance(code, (list, tuple)): - code = code[0] - - # Only trigger on key down - if key.keystate != 1: - continue - - if events["in_reset"]: - # During reset: either pedal starts next episode - if code in [KEY_LEFT, KEY_RIGHT]: - print("\n[Pedal] Starting next episode...") - events["start_next_episode"] = True - else: - # During episode - if code == KEY_RIGHT: - # Right pedal: SPACE (pause) when running, → (next) when in correction - if events["correction_active"]: - print("\n[Pedal] → End episode") - events["exit_early"] = True - elif not events["policy_paused"]: - print("\n[Pedal] ⏸ PAUSED - Policy stopped, teleop moving to robot") - print(" Press left pedal to take control") - events["policy_paused"] = True - - elif code == KEY_LEFT: - # Left pedal: START (take control) when paused - if events["policy_paused"] and not events["correction_active"]: - print("\n[Pedal] ▶ START pressed - taking control") - events["start_next_episode"] = True - - except FileNotFoundError: - logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") - except PermissionError: - logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}") - except Exception as e: - logging.debug(f"[Pedal] Error: {e}") - - thread = threading.Thread(target=pedal_reader, daemon=True) - thread.start() - - -def make_identity_processors(): - """Create identity processors for RaC recording.""" - teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( - steps=[IdentityProcessor()], - to_transition=robot_action_observation_to_transition, - to_output=transition_to_robot_action, - ) - robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( - steps=[IdentityProcessor()], - to_transition=robot_action_observation_to_transition, - to_output=transition_to_robot_action, - ) - obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation]( - steps=[IdentityProcessor()], - to_transition=observation_to_transition, - to_output=transition_to_observation, - ) - return teleop_proc, robot_proc, obs_proc - - -def move_robot_to_zero(robot: Robot, duration_s: float = 2.0, fps: int = 50): - """Smoothly move all robot joints to zero position.""" - obs = robot.get_observation() - current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} - target_pos = {k: 0.0 for k in current_pos} - - print(f"[RaC] Moving robot to zero position ({duration_s}s)...") - steps = int(duration_s * fps) - for step in range(steps + 1): - t = step / steps - interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos} - robot.send_action(interp_pos) - time.sleep(1 / fps) - print("[RaC] Robot at zero position.") - -@safe_stop_image_writer -def rac_rollout_loop( - robot: Robot, - teleop: Teleoperator, - policy: PreTrainedPolicy, - preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], - dataset: LeRobotDataset, - events: dict, - fps: int, - control_time_s: float, - single_task: str, - display_data: bool = True, -) -> dict: - """ - RaC rollout loop with two-stage intervention: - - 1. Policy runs autonomously (recording) - 2. SPACE: Policy pauses (NOT recording) - robot holds position - 3. 'c': Human takes control (recording correction) - 4. →: End episode - """ - policy.reset() - preprocessor.reset() - postprocessor.reset() - - device = get_safe_torch_device(policy.config.device) - frame_buffer = [] - - stats = { - "total_frames": 0, - "autonomous_frames": 0, - "paused_frames": 0, - "correction_frames": 0, - } - - last_robot_action = None - was_paused = False - was_correction_active = False - waiting_for_takeover = False - timestamp = 0 - start_t = time.perf_counter() - - while timestamp < control_time_s: - loop_start = time.perf_counter() - - if events["exit_early"]: - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - break - - # Detect transition to paused state - if events["policy_paused"] and not was_paused: - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} - print("[RaC] Moving teleop to robot position (2s smooth transition)...") - teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - print("[RaC] Teleop aligned. Press START to take control.") - events["start_next_episode"] = False - waiting_for_takeover = True - was_paused = True - - # Wait for start button before enabling correction mode - if waiting_for_takeover and events["start_next_episode"]: - print("[RaC] Start pressed - enabling teleop control...") - events["start_next_episode"] = False - events["correction_active"] = True - waiting_for_takeover = False - was_correction_active = True - - obs = robot.get_observation() - obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR) - - if events["correction_active"]: - # Human controlling - record correction data - robot_action = teleop.get_action() - robot.send_action(robot_action) - stats["correction_frames"] += 1 - - # Record this frame - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - frame = {**obs_frame, **action_frame, "task": single_task} - frame_buffer.append(frame) - stats["total_frames"] += 1 - - elif waiting_for_takeover: - # Waiting for START - policy stopped, no recording, robot holds position - if last_robot_action is not None: - robot.send_action(last_robot_action) - stats["paused_frames"] += 1 - - elif events["policy_paused"]: - # Paused and user acknowledged - hold last position, don't record - if last_robot_action is not None: - robot.send_action(last_robot_action) - stats["paused_frames"] += 1 - robot_action = last_robot_action - - else: - # Normal policy execution - record - action_values = predict_action( - observation=obs_frame, - policy=policy, - device=device, - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=single_task, - robot_type=robot.robot_type, - ) - robot_action: RobotAction = make_robot_action(action_values, dataset.features) - robot.send_action(robot_action) - last_robot_action = robot_action - stats["autonomous_frames"] += 1 - - # Record this frame - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - frame = {**obs_frame, **action_frame, "task": single_task} - frame_buffer.append(frame) - stats["total_frames"] += 1 - - if display_data and robot_action is not None: - log_rerun_data(observation=obs, action=robot_action) - - dt = time.perf_counter() - loop_start - precise_sleep(1 / fps - dt) - timestamp = time.perf_counter() - start_t - - for frame in frame_buffer: - dataset.add_frame(frame) - - return stats - - -def reset_loop( - robot: Robot, - teleop: Teleoperator, - events: dict, - fps: int, -): - """Reset period where human repositions environment. Two-stage: enable teleop, then start episode.""" - print("\n" + "=" * 65) - print(" [RaC] RESET - Moving teleop to robot position...") - print("=" * 65) - - # Enter reset mode - events["in_reset"] = True - events["start_next_episode"] = False - - # Move teleop to match robot position to avoid sudden jumps - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos")} - teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - - # Stage 1: Wait for user to press start to enable teleoperation - print(" Teleop aligned. Press any key/pedal to enable teleoperation") - while not events["start_next_episode"] and not events["stop_recording"]: - precise_sleep(0.05) - - if events["stop_recording"]: - return - - # Stage 2: Enable teleop and let user move robot to starting position - events["start_next_episode"] = False - teleop.disable_torque() - print(" Teleop enabled - move robot to starting position") - print(" Press any key/pedal to start next episode") - - # Wait for user to signal ready for next episode - while not events["start_next_episode"] and not events["stop_recording"]: - loop_start = time.perf_counter() - - action = teleop.get_action() - robot.send_action(action) - - dt = time.perf_counter() - loop_start - precise_sleep(1 / fps - dt) - - # Exit reset mode and clear flags for next episode - events["in_reset"] = False - events["start_next_episode"] = False - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - - -@parser.wrap() -def rac_collect(cfg: RaCConfig) -> LeRobotDataset: - """Main RaC data collection function.""" - init_logging() - logging.info(pformat(cfg.__dict__)) - - if cfg.display_data: - init_rerun(session_name="rac_collection") - - robot = make_robot_from_config(cfg.robot) - teleop = make_teleoperator_from_config(cfg.teleop) - - teleop_proc, robot_proc, obs_proc = make_identity_processors() - - dataset_features = combine_feature_dicts( - aggregate_pipeline_dataset_features( - pipeline=teleop_proc, - initial_features=create_initial_features(action=robot.action_features), - use_videos=cfg.dataset.video, - ), - aggregate_pipeline_dataset_features( - pipeline=obs_proc, - initial_features=create_initial_features(observation=robot.observation_features), - use_videos=cfg.dataset.video, - ), - ) - - dataset = None - listener = None - - try: - if cfg.resume: - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - ) - if hasattr(robot, "cameras") and robot.cameras: - dataset.start_image_writer( - num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), - ) - else: - dataset = LeRobotDataset.create( - cfg.dataset.repo_id, - cfg.dataset.fps, - root=cfg.dataset.root, - robot_type=robot.name, - features=dataset_features, - use_videos=cfg.dataset.video, - image_writer_processes=cfg.dataset.num_image_writer_processes, - image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera - * len(robot.cameras if hasattr(robot, "cameras") else []), - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - ) - - policy = make_policy(cfg.policy, ds_meta=dataset.meta) - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.policy.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) - - robot.connect() - teleop.connect() - listener, events = init_rac_keyboard_listener() - - print("\n" + "=" * 65) - print(" RaC (Recovery and Correction) Data Collection") - print("=" * 65) - print(" Policy runs autonomously until you intervene.") - print() - print(" Controls:") - print(" SPACE - Pause policy (robot holds position, no recording)") - print(" c - Take control (start correction, recording)") - print(" → - End episode (save)") - print(" ← - Re-record episode") - print(" ESC - Stop session and push to hub") - print("=" * 65 + "\n") - - with VideoEncodingManager(dataset): - recorded = 0 - while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) - - move_robot_to_zero(robot, duration_s=2.0, fps=cfg.dataset.fps) - - stats = rac_rollout_loop( - robot=robot, - teleop=teleop, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset=dataset, - events=events, - fps=cfg.dataset.fps, - control_time_s=cfg.dataset.episode_time_s, - single_task=cfg.dataset.single_task, - display_data=cfg.display_data, - ) - - logging.info(f"Episode stats: {stats}") - - if events["rerecord_episode"]: - log_say("Re-recording", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded += 1 - - # Reset between episodes - if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - reset_loop( - robot=robot, - teleop=teleop, - events=events, - fps=cfg.dataset.fps, - ) - - finally: - log_say("Stop recording", cfg.play_sounds, blocking=True) - - if dataset: - dataset.finalize() - - if robot.is_connected: - robot.disconnect() - if teleop.is_connected: - teleop.disconnect() - - if not is_headless() and listener: - listener.stop() - - if cfg.dataset.push_to_hub: - dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) - - return dataset - - -def main(): - from lerobot.utils.import_utils import register_third_party_plugins - - register_third_party_plugins() - rac_collect() - - -if __name__ == "__main__": - main() - - diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py deleted file mode 100644 index f89248ff8..000000000 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ /dev/null @@ -1,889 +0,0 @@ -#!/usr/bin/env python -""" -RaC (Recovery and Correction) Data Collection for OpenArms Robot with RTC. - -This combines RaC data collection with Real-Time Chunking (RTC) for smooth policy execution. -RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion -despite high inference latency by asynchronously generating action chunks. - -The workflow: -1. Policy runs autonomously with RTC (teleop is idle/free) -2. Press SPACE to pause - teleop moves to match robot position -3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION -4. Press → to end episode (save and continue to next) -5. Reset, then do next rollout - -Usage: - python examples/rac/rac_data_collection_openarms_rtc.py \ - --robot.port_right=can0 \ - --robot.port_left=can1 \ - --teleop.port_right=/dev/ttyUSB0 \ - --teleop.port_left=/dev/ttyUSB1 \ - --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ - --dataset.repo_id=my_user/rac_openarms_dataset \ - --dataset.single_task="Pick up the cube" -""" - -import logging -import math -import time -from dataclasses import dataclass, field -from pathlib import Path -from pprint import pformat -from threading import Event, Lock, Thread -from typing import Any - -import torch -from torch import Tensor - -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import RTCAttentionSchedule -from lerobot.datasets.image_writer import safe_stop_image_writer -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features -from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import get_policy_class, make_pre_post_processors -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.rtc.action_queue import ActionQueue -from lerobot.policies.rtc.configuration_rtc import RTCConfig -from lerobot.policies.rtc.latency_tracker import LatencyTracker -from lerobot.policies.utils import make_robot_action -from lerobot.processor import ( - IdentityProcessorStep, - PolicyAction, - PolicyProcessorPipeline, - RobotAction, - RobotObservation, - RobotProcessorPipeline, -) -from lerobot.processor.converters import ( - observation_to_transition, - robot_action_observation_to_transition, - transition_to_observation, - transition_to_robot_action, -) -from lerobot.processor.rename_processor import rename_stats -from lerobot.robots import Robot, RobotConfig, make_robot_from_config -from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401 -from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config -from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401 -from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.control_utils import is_headless, predict_action -from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say -from lerobot.utils.visualization_utils import init_rerun, log_rerun_data - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -# ============================================================================ -# Configuration -# ============================================================================ - -@dataclass -class RaCRTCDatasetConfig: - repo_id: str = "lerobot/rac_openarms_rtc" - single_task: str = "default task" - root: str | Path | None = None - fps: int = 30 - episode_time_s: float = 500 - reset_time_s: float = 30 - num_episodes: int = 50 - video: bool = True - push_to_hub: bool = True - private: bool = False - tags: list[str] | None = None - num_image_writer_processes: int = 0 - num_image_writer_threads_per_camera: int = 4 - video_encoding_batch_size: int = 1 - rename_map: dict[str, str] = field(default_factory=dict) - - -@dataclass -class RaCRTCConfig: - robot: RobotConfig = field(default_factory=lambda: OpenArmsFollowerConfig( - port_left="can0", - port_right="can1", - )) - teleop: TeleoperatorConfig = field(default_factory=lambda: OpenArmsMiniConfig( - port_left="/dev/ttyUSB1", - port_right="/dev/ttyUSB0", - )) - dataset: RaCRTCDatasetConfig = field(default_factory=RaCRTCDatasetConfig) - policy: PreTrainedConfig | None = None - - rtc: RTCConfig = field(default_factory=lambda: RTCConfig( - enabled=True, - execution_horizon=20, - max_guidance_weight=5.0, - prefix_attention_schedule=RTCAttentionSchedule.LINEAR, - )) - - interpolation: bool = True - display_data: bool = True - play_sounds: bool = True - resume: bool = False - device: str = "cuda" - action_queue_size_to_get_new_actions: int = 30 - - # Torch compile is disabled by default for real-time inference - # First inference with compile takes minutes to compile kernels - use_torch_compile: bool = False - - def __post_init__(self): - policy_path = parser.get_path_arg("policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - if self.policy is None: - raise ValueError("policy.path is required") - - @classmethod - def __get_path_fields__(cls) -> list[str]: - return ["policy"] - - -# ============================================================================ -# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py) -# ============================================================================ - -class RobotWrapper: - """Thread-safe wrapper for robot operations.""" - - def __init__(self, robot: Robot): - self.robot = robot - self.lock = Lock() - - def get_observation(self) -> dict[str, Tensor]: - with self.lock: - return self.robot.get_observation() - - def send_action(self, action: dict) -> None: - with self.lock: - self.robot.send_action(action) - - @property - def observation_features(self) -> dict: - return self.robot.observation_features - - @property - def action_features(self) -> dict: - return self.robot.action_features - - @property - def name(self) -> str: - return self.robot.name - - @property - def robot_type(self) -> str: - return self.robot.robot_type - - -# ============================================================================ -# Keyboard/Pedal Listeners -# ============================================================================ - -def init_rac_keyboard_listener(): - """Initialize keyboard listener with RaC-specific controls.""" - events = { - "exit_early": False, - "rerecord_episode": False, - "stop_recording": False, - "policy_paused": False, - "correction_active": False, - "in_reset": False, - "start_next_episode": False, - } - - if is_headless(): - logging.warning("Headless environment - keyboard controls unavailable") - return None, events - - from pynput import keyboard - - def on_press(key): - try: - if events["in_reset"]: - if key == keyboard.Key.space or key == keyboard.Key.right: - print("\n[RaC] Starting next episode...") - events["start_next_episode"] = True - elif hasattr(key, 'char') and key.char == 'c': - print("\n[RaC] Starting next episode...") - events["start_next_episode"] = True - elif key == keyboard.Key.esc: - print("[RaC] ESC - Stop recording, pushing to hub...") - events["stop_recording"] = True - events["start_next_episode"] = True - else: - if key == keyboard.Key.space: - if not events["policy_paused"] and not events["correction_active"]: - print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position") - print(" Press 'c' or START to take control") - events["policy_paused"] = True - elif hasattr(key, 'char') and key.char == 'c': - if events["policy_paused"] and not events["correction_active"]: - print("\n[RaC] ▶ START pressed - taking control") - events["start_next_episode"] = True - elif key == keyboard.Key.right: - print("[RaC] → End episode") - events["exit_early"] = True - elif key == keyboard.Key.left: - print("[RaC] ← Re-record episode") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("[RaC] ESC - Stop recording, pushing to hub...") - events["stop_recording"] = True - events["exit_early"] = True - except Exception as e: - print(f"Key error: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - start_pedal_listener(events) - - return listener, events - - -def start_pedal_listener(events: dict): - """Start foot pedal listener thread if evdev is available.""" - import threading - - try: - from evdev import InputDevice, ecodes # noqa: F401 - except ImportError: - logging.info("[Pedal] evdev not installed - pedal support disabled") - return - - PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" - KEY_LEFT = "KEY_A" - KEY_RIGHT = "KEY_C" - - def pedal_reader(): - try: - dev = InputDevice(PEDAL_DEVICE) - print(f"[Pedal] Connected: {dev.name}") - - for ev in dev.read_loop(): - if ev.type != ecodes.EV_KEY: - continue - - from evdev import categorize # noqa: F401 - key = categorize(ev) - code = key.keycode - if isinstance(code, (list, tuple)): - code = code[0] - - if key.keystate != 1: - continue - - if events["in_reset"]: - if code in [KEY_LEFT, KEY_RIGHT]: - events["start_next_episode"] = True - else: - if code == KEY_RIGHT: - if events["correction_active"]: - events["exit_early"] = True - elif not events["policy_paused"]: - events["policy_paused"] = True - elif code == KEY_LEFT: - if events["policy_paused"] and not events["correction_active"]: - events["start_next_episode"] = True - - except FileNotFoundError: - logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") - except PermissionError: - logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}") - except Exception as e: - logging.debug(f"[Pedal] Error: {e}") - - thread = threading.Thread(target=pedal_reader, daemon=True) - thread.start() - - -def make_identity_processors(): - """Create identity processors for RaC recording.""" - teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( - steps=[IdentityProcessorStep()], - to_transition=robot_action_observation_to_transition, - to_output=transition_to_robot_action, - ) - robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( - steps=[IdentityProcessorStep()], - to_transition=robot_action_observation_to_transition, - to_output=transition_to_robot_action, - ) - obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation]( - steps=[IdentityProcessorStep()], - to_transition=observation_to_transition, - to_output=transition_to_observation, - ) - return teleop_proc, robot_proc, obs_proc - - -# ============================================================================ -# RTC Inference Thread (from evaluate_with_rtc.py) -# ============================================================================ - -def rtc_inference_thread( - policy, - obs_holder: dict, - hw_features: dict, - preprocessor, - postprocessor, - queue_holder: dict, - shutdown_event: Event, - policy_active: Event, - cfg: RaCRTCConfig, -): - """Background thread that generates action chunks using RTC.""" - try: - logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========") - logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys") - - latency_tracker = LatencyTracker() - time_per_chunk = 1.0 / cfg.dataset.fps - policy_device = policy.config.device - - get_actions_threshold = cfg.action_queue_size_to_get_new_actions - if not cfg.rtc.enabled: - get_actions_threshold = 0 - - inference_count = 0 - wait_logged = False - - while not shutdown_event.is_set(): - if not policy_active.is_set(): - if not wait_logged: - logger.info("[RTC] Waiting for policy_active...") - wait_logged = True - time.sleep(0.01) - continue - - wait_logged = False - - action_queue = queue_holder["queue"] - if action_queue is None: - logger.warning("[RTC] queue_holder['queue'] is None!") - time.sleep(0.01) - continue - - obs_filtered = obs_holder.get("obs") - if obs_filtered is None: - logger.warning("[RTC] obs_holder['obs'] is None!") - time.sleep(0.01) - continue - - qsize = action_queue.qsize() - if qsize <= get_actions_threshold: - try: - if inference_count == 0: - logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}") - - current_time = time.perf_counter() - action_index_before_inference = action_queue.get_action_index() - prev_actions = action_queue.get_left_over() - - inference_latency = latency_tracker.max() - inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 - - obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation") - - for name in obs_with_policy_features: - obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) - if "image" in name: - obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255 - obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous() - obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device) - - obs_with_policy_features["task"] = [cfg.dataset.single_task] - obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower") - - preprocessed_obs = preprocessor(obs_with_policy_features) - - actions = policy.predict_action_chunk( - preprocessed_obs, - inference_delay=inference_delay, - prev_chunk_left_over=prev_actions, - ) - - original_actions = actions.squeeze(0).clone() - postprocessed_actions = postprocessor(actions).squeeze(0) - - new_latency = time.perf_counter() - current_time - new_delay = math.ceil(new_latency / time_per_chunk) - latency_tracker.add(new_latency) - - action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference) - - inference_count += 1 - logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}") - except Exception as e: - logger.error(f"[RTC] Inference error: {e}") - import traceback - traceback.print_exc() - time.sleep(1.0) - else: - time.sleep(0.01) - - logger.info("[RTC] Inference thread shutting down") - except Exception as e: - logger.error(f"[RTC] THREAD CRASHED: {e}") - import traceback - traceback.print_exc() - - -# ============================================================================ -# Main Rollout Loop -# ============================================================================ - -@safe_stop_image_writer -def rac_rtc_rollout_loop( - robot: RobotWrapper, - teleop: Teleoperator, - policy: PreTrainedPolicy, - preprocessor, - postprocessor, - dataset: LeRobotDataset, - events: dict, - cfg: RaCRTCConfig, - queue_holder: dict, - obs_holder: dict, # Main loop writes obs here for RTC thread to read - policy_active: Event, - hw_features: dict, -) -> dict: - """RaC rollout loop with RTC for smooth policy execution.""" - fps = cfg.dataset.fps - single_task = cfg.dataset.single_task - control_time_s = cfg.dataset.episode_time_s - device = get_safe_torch_device(cfg.device) - - # Reset policy state - policy.reset() - preprocessor.reset() - postprocessor.reset() - - frame_buffer = [] - stats = { - "total_frames": 0, - "autonomous_frames": 0, - "paused_frames": 0, - "correction_frames": 0, - } - - teleop.disable_torque() - was_paused = False - waiting_for_takeover = False - - # Action keys for converting tensor to dict - action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - - # Interpolation state - prev_action: Tensor | None = None - interpolated_actions: list[Tensor] = [] - interp_idx = 0 - - if cfg.interpolation: - control_interval = 1.0 / (fps * 2) # 2x rate - else: - control_interval = 1.0 / fps - - robot_action = {} - timestamp = 0 - start_t = time.perf_counter() - - while timestamp < control_time_s: - loop_start = time.perf_counter() - - if events["exit_early"]: - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - break - - # State transition: entering paused state - if events["policy_paused"] and not was_paused: - policy_active.clear() # Stop RTC inference - obs = robot.get_observation() - obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} - robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} - print("[RaC] Moving teleop to robot position...") - teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - print("[RaC] Teleop aligned. Press 'c' to take control.") - events["start_next_episode"] = False - waiting_for_takeover = True - was_paused = True - # Reset interpolation - prev_action = None - interpolated_actions = [] - interp_idx = 0 - - # Wait for takeover - if waiting_for_takeover and events["start_next_episode"]: - print("[RaC] Taking control...") - teleop.disable_torque() - events["start_next_episode"] = False - events["correction_active"] = True - waiting_for_takeover = False - - # Get observation (ONLY the main loop reads from robot!) - obs = robot.get_observation() - obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} - obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) - - # Share observation with RTC thread (thread reads, main loop writes) - obs_holder["obs"] = obs_filtered - - if events["correction_active"]: - # Human controlling - robot_action = teleop.get_action() - for key in robot_action: - if "gripper" in key: - robot_action[key] = -0.65 * robot_action[key] - robot.send_action(robot_action) - stats["correction_frames"] += 1 - - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - frame = {**obs_frame, **action_frame, "task": single_task} - frame_buffer.append(frame) - stats["total_frames"] += 1 - - elif waiting_for_takeover: - stats["paused_frames"] += 1 - - elif events["policy_paused"]: - robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} - teleop.send_feedback(robot_pos) - stats["paused_frames"] += 1 - - else: - # Policy execution with RTC - if not policy_active.is_set(): - policy_active.set() - logger.info("[ROLLOUT] Policy activated, waiting for first actions...") - - action_queue = queue_holder["queue"] - - # Get action from queue (with interpolation) - if interp_idx >= len(interpolated_actions): - new_action = action_queue.get() if action_queue else None - - # Log queue status periodically - if stats["autonomous_frames"] == 0 and new_action is None: - qsize = action_queue.qsize() if action_queue else -1 - if timestamp < 0.5 or int(timestamp * 10) % 10 == 0: - logger.info(f"[ROLLOUT] Waiting for actions... queue_size={qsize}, obs_set={obs_holder.get('obs') is not None}") - - if new_action is not None: - current_action = new_action.cpu() - - if cfg.interpolation and prev_action is not None: - mid = prev_action + 0.5 * (current_action - prev_action) - interpolated_actions = [mid, current_action] - else: - interpolated_actions = [current_action] - - prev_action = current_action - interp_idx = 0 - - if stats["autonomous_frames"] == 0: - logger.info(f"[ROLLOUT] Got first action! Starting robot motion.") - - if interp_idx < len(interpolated_actions): - action_to_send = interpolated_actions[interp_idx] - interp_idx += 1 - - robot_action = {} - for i, key in enumerate(action_keys): - if i < len(action_to_send): - robot_action[key] = action_to_send[i].item() - - robot.send_action(robot_action) - stats["autonomous_frames"] += 1 - - # Record at original fps - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - frame = {**obs_frame, **action_frame, "task": single_task} - frame_buffer.append(frame) - stats["total_frames"] += 1 - - if cfg.display_data: - log_rerun_data(observation=obs_filtered, action=robot_action) - - dt = time.perf_counter() - loop_start - sleep_time = control_interval - dt - if sleep_time > 0: - precise_sleep(sleep_time) - timestamp = time.perf_counter() - start_t - - policy_active.clear() - teleop.disable_torque() - - for frame in frame_buffer: - dataset.add_frame(frame) - - return stats - - -def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int): - """Reset period where human repositions environment.""" - print("\n" + "=" * 65) - print(" [RaC] RESET") - print("=" * 65) - - events["in_reset"] = True - events["start_next_episode"] = False - - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} - teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - - print(" Press any key/pedal to enable teleoperation") - while not events["start_next_episode"] and not events["stop_recording"]: - precise_sleep(0.05) - - if events["stop_recording"]: - return - - events["start_next_episode"] = False - teleop.disable_torque() - print(" Teleop enabled - press any key/pedal to start episode") - - while not events["start_next_episode"] and not events["stop_recording"]: - loop_start = time.perf_counter() - action = teleop.get_action() - for key in action: - if "gripper" in key: - action[key] = -0.65 * action[key] - robot.send_action(action) - dt = time.perf_counter() - loop_start - precise_sleep(1 / fps - dt) - - events["in_reset"] = False - events["start_next_episode"] = False - events["exit_early"] = False - events["policy_paused"] = False - events["correction_active"] = False - - -# ============================================================================ -# Main Entry Point -# ============================================================================ - -@parser.wrap() -def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: - """Main RaC data collection function with RTC.""" - init_logging() - logging.info(pformat(cfg.__dict__)) - - if cfg.display_data: - init_rerun(session_name="rac_rtc_collection_openarms") - - robot_raw = make_robot_from_config(cfg.robot) - teleop = make_teleoperator_from_config(cfg.teleop) - - teleop_proc, robot_proc, obs_proc = make_identity_processors() - - dataset_features = combine_feature_dicts( - aggregate_pipeline_dataset_features( - pipeline=teleop_proc, - initial_features=create_initial_features(action=robot_raw.action_features), - use_videos=cfg.dataset.video, - ), - aggregate_pipeline_dataset_features( - pipeline=obs_proc, - initial_features=create_initial_features(observation=robot_raw.observation_features), - use_videos=cfg.dataset.video, - ), - ) - - dataset = None - listener = None - shutdown_event = Event() - policy_active = Event() - rtc_thread = None - - try: - if cfg.resume: - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - ) - if hasattr(robot_raw, "cameras") and robot_raw.cameras: - dataset.start_image_writer( - num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras), - ) - else: - dataset = LeRobotDataset.create( - cfg.dataset.repo_id, - cfg.dataset.fps, - root=cfg.dataset.root, - robot_type=robot_raw.name, - features=dataset_features, - use_videos=cfg.dataset.video, - image_writer_processes=cfg.dataset.num_image_writer_processes, - image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera - * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []), - batch_encoding_size=cfg.dataset.video_encoding_batch_size, - ) - - # Load policy - logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") - policy_class = get_policy_class(cfg.policy.type) - - # Override compile_model for real-time inference (first compile takes minutes) - policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) - if cfg.policy.type in ["pi05", "pi0"]: - policy_config.compile_model = cfg.use_torch_compile - logger.info(f"Set compile_model={cfg.use_torch_compile} for real-time inference") - - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config) - policy.config.rtc_config = cfg.rtc - policy.init_rtc_processor() - policy = policy.to(cfg.device) - policy.eval() - logger.info(f"Policy loaded: {policy.name}") - - # Setup preprocessor/postprocessor - hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation") - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) - - # Connect robot and wrap for thread safety - robot_raw.connect() - robot = RobotWrapper(robot_raw) - - teleop.connect() - listener, events = init_rac_keyboard_listener() - - # Shared state holders (main loop writes, RTC thread reads) - queue_holder = {"queue": ActionQueue(cfg.rtc)} - obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs - - # Start RTC inference thread - # NOTE: Thread does NOT access robot directly - reads from obs_holder - rtc_thread = Thread( - target=rtc_inference_thread, - args=( - policy, - obs_holder, # Thread reads obs from here (set by main loop) - hw_features, - preprocessor, - postprocessor, - queue_holder, - shutdown_event, - policy_active, - cfg, - ), - daemon=True, - name="RTCInference", - ) - rtc_thread.start() - logger.info("Started RTC inference thread") - - print("\n" + "=" * 65) - print(" RaC Data Collection with RTC") - print("=" * 65) - print(f" Policy: {cfg.policy.pretrained_path}") - print(f" Task: {cfg.dataset.single_task}") - print(f" FPS: {cfg.dataset.fps}") - print(f" Interpolation: {cfg.interpolation}") - print() - print(" Controls:") - print(" SPACE - Pause policy") - print(" c - Take control") - print(" → - End episode") - print(" ESC - Stop and push to hub") - print("=" * 65 + "\n") - - with VideoEncodingManager(dataset): - recorded = 0 - while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) - - # Fresh action queue per episode (update holder so thread sees it) - queue_holder["queue"] = ActionQueue(cfg.rtc) - - logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}") - - stats = rac_rtc_rollout_loop( - robot=robot, - teleop=teleop, - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset=dataset, - events=events, - cfg=cfg, - queue_holder=queue_holder, - obs_holder=obs_holder, - policy_active=policy_active, - hw_features=hw_features, - ) - - logging.info(f"Episode stats: {stats}") - - if events["rerecord_episode"]: - log_say("Re-recording", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded += 1 - - if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - reset_loop(robot, teleop, events, cfg.dataset.fps) - - finally: - log_say("Stop recording", cfg.play_sounds, blocking=True) - - shutdown_event.set() - policy_active.clear() - - if rtc_thread and rtc_thread.is_alive(): - rtc_thread.join(timeout=2.0) - - if dataset: - dataset.finalize() - - if robot_raw.is_connected: - robot_raw.disconnect() - if teleop.is_connected: - teleop.disconnect() - - if not is_headless() and listener: - listener.stop() - - if cfg.dataset.push_to_hub: - dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) - - return dataset - - -def main(): - from lerobot.utils.import_utils import register_third_party_plugins - register_third_party_plugins() - rac_rtc_collect() - - -if __name__ == "__main__": - main() - diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 6f051485a..5bb9fbbe3 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -83,9 +83,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features from lerobot.policies.factory import get_policy_class, make_pre_post_processors -from lerobot.policies.rtc.action_queue import ActionQueue -from lerobot.policies.rtc.configuration_rtc import RTCConfig -from lerobot.policies.rtc.latency_tracker import LatencyTracker +from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig from lerobot.processor.factory import ( make_default_robot_action_processor, make_default_robot_observation_processor, @@ -151,6 +149,7 @@ class RTCDemoConfig(HubMixin): # Demo parameters duration: float = 30.0 # Duration to run the demo (seconds) fps: float = 10.0 # Action execution frequency (Hz) + interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x) # Compute device device: str | None = None # Device to run on (cuda, cpu, auto) @@ -351,20 +350,22 @@ def actor_control( logger.info("[ACTOR] Starting actor thread") action_count = 0 - action_interval = 1.0 / cfg.fps + interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) + action_interval = interpolator.get_control_interval(cfg.fps) while not shutdown_event.is_set(): start_time = time.perf_counter() - # Try to get an action from the queue with timeout - action = action_queue.get() + if interpolator.needs_new_action(): + new_action = action_queue.get() + if new_action is not None: + interpolator.add(new_action.cpu()) + action = interpolator.get() if action is not None: - action = action.cpu() action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())} action_processed = robot_action_processor((action_dict, None)) robot.send_action(action_processed) - action_count += 1 dt_s = time.perf_counter() - start_time diff --git a/src/lerobot/policies/rtc/__init__.py b/src/lerobot/policies/rtc/__init__.py new file mode 100644 index 000000000..9d1620b6e --- /dev/null +++ b/src/lerobot/policies/rtc/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Real-Time Chunking (RTC) utilities for action-chunking policies.""" + +from lerobot.policies.rtc.action_interpolator import ActionInterpolator +from lerobot.policies.rtc.action_queue import ActionQueue +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.latency_tracker import LatencyTracker +from lerobot.policies.rtc.modeling_rtc import RTCProcessor + +__all__ = [ + "ActionInterpolator", + "ActionQueue", + "LatencyTracker", + "RTCConfig", + "RTCProcessor", +] + diff --git a/src/lerobot/policies/rtc/action_interpolator.py b/src/lerobot/policies/rtc/action_interpolator.py new file mode 100644 index 000000000..50be16e08 --- /dev/null +++ b/src/lerobot/policies/rtc/action_interpolator.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Action interpolation for smoother robot control. + +Provides configurable Nx control rate by interpolating between consecutive actions. +Useful with RTC and action-chunking policies to reduce jerkiness. +""" + +from torch import Tensor + + +class ActionInterpolator: + """Interpolates between consecutive actions for smoother control. + + When enabled with multiplier N, produces N actions per policy action + by linearly interpolating between the previous and current action. + + Example with multiplier=3: + prev_action -> [1/3 interpolated, 2/3 interpolated, current_action] + + This effectively multiplies the control rate for smoother motion. + + Usage: + interpolator = ActionInterpolator(multiplier=2) # 2x control rate + + # In control loop: + if interpolator.needs_new_action(): + new_action = queue.get() + if new_action: + interpolator.add(new_action.cpu()) + + action = interpolator.get() + if action: + robot.send_action(action) + """ + + def __init__(self, multiplier: int = 1): + """Initialize the interpolator. + + Args: + multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.) + """ + if multiplier < 1: + raise ValueError(f"multiplier must be >= 1, got {multiplier}") + self.multiplier = multiplier + self._prev: Tensor | None = None + self._buffer: list[Tensor] = [] + self._idx = 0 + + @property + def enabled(self) -> bool: + """Whether interpolation is active (multiplier > 1).""" + return self.multiplier > 1 + + def reset(self): + """Reset interpolation state (call between episodes).""" + self._prev = None + self._buffer = [] + self._idx = 0 + + def needs_new_action(self) -> bool: + """Check if a new action is needed from the queue.""" + return self._idx >= len(self._buffer) + + def add(self, action: Tensor) -> None: + """Add a new action and compute interpolated sequence. + + Args: + action: New action tensor from policy/queue (already on CPU). + """ + if self.multiplier > 1 and self._prev is not None: + self._buffer = [] + for i in range(1, self.multiplier + 1): + t = i / self.multiplier + interp = self._prev + t * (action - self._prev) + self._buffer.append(interp) + else: + self._buffer = [action] + self._prev = action + self._idx = 0 + + def get(self) -> Tensor | None: + """Get the next interpolated action. + + Returns: + Next action tensor, or None if buffer is exhausted. + """ + if self._idx >= len(self._buffer): + return None + action = self._buffer[self._idx] + self._idx += 1 + return action + + def get_control_interval(self, fps: float) -> float: + """Get the control interval based on interpolation multiplier. + + Args: + fps: Base frames per second. + + Returns: + Control interval in seconds (divided by multiplier). + """ + return 1.0 / (fps * self.multiplier) + diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 948e92bb8..79e0c8e33 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -65,6 +65,8 @@ from pathlib import Path from pprint import pformat from typing import Any +import torch + from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 ) @@ -79,6 +81,7 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc import ActionInterpolator from lerobot.policies.utils import make_robot_action from lerobot.processor import ( PolicyAction, @@ -189,6 +192,9 @@ class RecordConfig: play_sounds: bool = True # Resume recording on an existing dataset. resume: bool = False + # Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x) + # Only applies when using a policy (not teleop) + interpolation_multiplier: int = 1 def __post_init__(self): # HACK: We parse again the cli args here to get the pretrained path if there was one. @@ -259,6 +265,7 @@ def record_loop( control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, + interpolator: ActionInterpolator | None = None, ): if dataset is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") @@ -294,6 +301,14 @@ def record_loop( preprocessor.reset() postprocessor.reset() + # Reset interpolator if provided + if interpolator is not None: + interpolator.reset() + + # Calculate control interval based on interpolation + use_interpolation = interpolator is not None and interpolator.enabled and policy is not None + control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps + timestamp = 0 start_episode_t = time.perf_counter() while timestamp < control_time_s: @@ -314,24 +329,58 @@ def record_loop( # Get action from either policy or teleop if policy is not None and preprocessor is not None and postprocessor is not None: - action_values = predict_action( - observation=observation_frame, - policy=policy, - device=get_safe_torch_device(policy.config.device), - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=single_task, - robot_type=robot.robot_type, - ) + # With interpolation: only call policy when interpolator needs new action + if use_interpolation: + # Get action keys from robot + action_keys = sorted(robot.action_features) - act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features) + if interpolator.needs_new_action(): + action_values = predict_action( + observation=observation_frame, + policy=policy, + device=get_safe_torch_device(policy.config.device), + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, + task=single_task, + robot_type=robot.robot_type, + ) + act_processed_policy = make_robot_action(action_values, dataset.features) + robot_action_to_send = robot_action_processor((act_processed_policy, obs)) + + # Convert to tensor for interpolator + action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys]) + interpolator.add(action_tensor) + + # Get interpolated action + interp_action = interpolator.get() + if interp_action is not None: + robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)} + action_values = robot_action_to_send + else: + # No action available yet, skip this iteration + continue + else: + action_values = predict_action( + observation=observation_frame, + policy=policy, + device=get_safe_torch_device(policy.config.device), + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, + task=single_task, + robot_type=robot.robot_type, + ) + act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features) + robot_action_to_send = robot_action_processor((act_processed_policy, obs)) elif policy is None and isinstance(teleop, Teleoperator): act = teleop.get_action() # Applies a pipeline to the raw teleop action, default is IdentityProcessor act_processed_teleop = teleop_action_processor((act, obs)) + action_values = act_processed_teleop + robot_action_to_send = robot_action_processor((act_processed_teleop, obs)) elif policy is None and isinstance(teleop, list): arm_action = teleop_arm.get_action() @@ -340,6 +389,8 @@ def record_loop( base_action = robot._from_keyboard_to_base_action(keyboard_action) act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action act_processed_teleop = teleop_action_processor((act, obs)) + action_values = act_processed_teleop + robot_action_to_send = robot_action_processor((act_processed_teleop, obs)) else: logging.info( "No policy or teleoperator provided, skipping action generation." @@ -348,14 +399,6 @@ def record_loop( ) continue - # Applies a pipeline to the action, default is IdentityProcessor - if policy is not None and act_processed_policy is not None: - action_values = act_processed_policy - robot_action_to_send = robot_action_processor((act_processed_policy, obs)) - else: - action_values = act_processed_teleop - robot_action_to_send = robot_action_processor((act_processed_teleop, obs)) - # Send action to robot # Action can eventually be clipped using `max_relative_target`, # so action actually sent is saved in the dataset. action = postprocessor.process(action) @@ -372,7 +415,7 @@ def record_loop( log_rerun_data(observation=obs_processed, action=action_values) dt_s = time.perf_counter() - start_loop_t - precise_sleep(1 / fps - dt_s) + precise_sleep(control_interval - dt_s) timestamp = time.perf_counter() - start_episode_t @@ -440,6 +483,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) preprocessor = None postprocessor = None + interpolator = None if cfg.policy is not None: preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, @@ -450,6 +494,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset: "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, }, ) + # Create interpolator for smoother policy control + if cfg.interpolation_multiplier > 1: + interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) + logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate") robot.connect() if teleop is not None: @@ -476,6 +524,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, + interpolator=interpolator, ) # Execute a few seconds without recording to give time to manually reset the environment