diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py index 4e28ef2b9..a707c9277 100644 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -1,45 +1,44 @@ #!/usr/bin/env python """ -RaC (Recovery and Correction) Data Collection with RTC for OpenArms Robot. +RaC (Recovery and Correction) Data Collection for OpenArms Robot with RTC. -Combines RaC paradigm with Real-Time Chunking (RTC) for smooth policy execution. -RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive -motion despite high inference latency by asynchronously generating action chunks. +This combines RaC data collection with Real-Time Chunking (RTC) for smooth policy execution. +RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive motion +despite high inference latency by asynchronously generating action chunks. The workflow: -1. Policy runs via RTC (async action generation) - teleop is idle -2. Press SPACE/right pedal to pause - teleop moves to match robot position -3. Press 'c'/left pedal to take control - human provides RECOVERY + CORRECTION -4. Press →/right pedal to end episode (save and continue to next) +1. Policy runs autonomously with RTC (teleop is idle/free) +2. Press SPACE to pause - teleop moves to match robot position +3. Press 'c' to take control - teleop is free, human provides RECOVERY + CORRECTION +4. Press → to end episode (save and continue to next) +5. Reset, then do next rollout -Controls: - SPACE/Right pedal - Pause policy - c/Left pedal - Take control (start correction) - →/Right pedal - End episode (when in correction mode) - ESC - Stop recording and push to hub +Keyboard Controls: + SPACE - Pause policy (teleop mirrors robot, no recording) + c - Take control (teleop free, recording correction) + → - End episode (save and continue to next) + ← - Re-record episode + ESC - Stop recording and push dataset to hub Usage: python examples/rac/rac_data_collection_openarms_rtc.py \ - --robot.type=openarms_follower \ --robot.port_right=can0 \ --robot.port_left=can1 \ - --robot.cameras="{ left_wrist: {type: opencv, index_or_path: 0}, right_wrist: {type: opencv, index_or_path: 2}}" \ - --teleop.type=openarms_mini \ --teleop.port_right=/dev/ttyUSB0 \ --teleop.port_left=/dev/ttyUSB1 \ --policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \ - --dataset.repo_id=my_user/rac_rtc_dataset \ + --dataset.repo_id=my_user/rac_openarms_dataset \ --dataset.single_task="Pick up the cube" """ import logging import math -import sys import time -import traceback from dataclasses import dataclass, field from pathlib import Path -from threading import Event, Lock, Thread +from pprint import pformat +from threading import Event, Thread +from typing import Any import torch from torch import Tensor @@ -49,91 +48,89 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import RTCAttentionSchedule +from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features +from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.latency_tracker import LatencyTracker -from lerobot.processor import make_default_processors -from lerobot.robots import Robot, RobotConfig, make_robot_from_config +from lerobot.processor import ( + IdentityProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RobotAction, + RobotObservation, + RobotProcessorPipeline, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.processor.rename_processor import rename_stats +from lerobot.robots import RobotConfig, make_robot_from_config from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401 from lerobot.teleoperators import TeleoperatorConfig, make_teleoperator_from_config from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401 +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import is_headless -from lerobot.utils.hub import HubMixin from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.utils import init_logging, log_say +from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -class RobotWrapper: - """Thread-safe wrapper for robot operations.""" - - def __init__(self, robot: Robot): - self.robot = robot - self.lock = Lock() - - def get_observation(self) -> dict[str, Tensor]: - with self.lock: - return self.robot.get_observation() - - def send_action(self, action: dict) -> None: - with self.lock: - self.robot.send_action(action) - - @property - def observation_features(self) -> dict: - return self.robot.observation_features - - @property - def action_features(self) -> dict: - return self.robot.action_features - - @property - def name(self) -> str: - return self.robot.name @dataclass class RaCRTCDatasetConfig: - """Dataset configuration for RaC + RTC.""" - repo_id: str = "lerobot/rac_rtc_openarms" - single_task: str = "task" + repo_id: str = "lerobot/rac_openarms_rtc" + single_task: str = "default task" root: str | Path | None = None fps: int = 30 episode_time_s: float = 120 + reset_time_s: float = 30 num_episodes: int = 50 video: bool = True push_to_hub: bool = True - private: bool = True + private: bool = False + tags: list[str] | None = None num_image_writer_processes: int = 0 num_image_writer_threads_per_camera: int = 4 + video_encoding_batch_size: int = 1 + rename_map: dict[str, str] = field(default_factory=dict) @dataclass -class RaCRTCConfig(HubMixin): - """Configuration for RaC data collection with RTC.""" - - robot: RobotConfig - teleop: TeleoperatorConfig - dataset: RaCRTCDatasetConfig +class RaCRTCConfig: + robot: RobotConfig = field(default_factory=lambda: OpenArmsFollowerConfig( + port_left="can0", + port_right="can1", + )) + teleop: TeleoperatorConfig = field(default_factory=lambda: OpenArmsMiniConfig( + port_left="/dev/ttyUSB1", + port_right="/dev/ttyUSB0", + )) + dataset: RaCRTCDatasetConfig = field(default_factory=RaCRTCDatasetConfig) policy: PreTrainedConfig | None = None - - rtc: RTCConfig = field( - default_factory=lambda: RTCConfig( - enabled=True, - execution_horizon=20, - max_guidance_weight=5.0, - prefix_attention_schedule=RTCAttentionSchedule.LINEAR, - ) - ) - + + rtc: RTCConfig = field(default_factory=lambda: RTCConfig( + enabled=True, + execution_horizon=10, + max_guidance_weight=10.0, + prefix_attention_schedule=RTCAttentionSchedule.EXP, + )) + + interpolation: bool = True + display_data: bool = True + play_sounds: bool = True + resume: bool = False device: str = "cuda" action_queue_size_to_get_new_actions: int = 30 - interpolation: bool = True - play_sounds: bool = True def __post_init__(self): policy_path = parser.get_path_arg("policy") @@ -149,11 +146,21 @@ class RaCRTCConfig(HubMixin): return ["policy"] -def init_keyboard_listener(events: dict): - """Initialize keyboard listener with RaC controls.""" +def init_rac_keyboard_listener(): + """Initialize keyboard listener with RaC-specific controls.""" + events = { + "exit_early": False, + "rerecord_episode": False, + "stop_recording": False, + "policy_paused": False, + "correction_active": False, + "in_reset": False, + "start_next_episode": False, + } + if is_headless(): - logger.warning("Headless environment - keyboard controls unavailable") - return None + logging.warning("Headless environment - keyboard controls unavailable") + return None, events from pynput import keyboard @@ -161,29 +168,34 @@ def init_keyboard_listener(events: dict): try: if events["in_reset"]: if key == keyboard.Key.space or key == keyboard.Key.right: + print("\n[RaC] Starting next episode...") events["start_next_episode"] = True - elif hasattr(key, "char") and key.char == "c": + elif hasattr(key, 'char') and key.char == 'c': + print("\n[RaC] Starting next episode...") events["start_next_episode"] = True elif key == keyboard.Key.esc: + print("[RaC] ESC - Stop recording, pushing to hub...") events["stop_recording"] = True events["start_next_episode"] = True else: if key == keyboard.Key.space: if not events["policy_paused"] and not events["correction_active"]: - print("\n[RaC] ⏸ PAUSED - Press 'c' to take control") + print("\n[RaC] ⏸ PAUSED - Policy stopped, teleop moving to robot position") + print(" Press 'c' or START to take control") events["policy_paused"] = True - elif hasattr(key, "char") and key.char == "c": + elif hasattr(key, 'char') and key.char == 'c': if events["policy_paused"] and not events["correction_active"]: - print("\n[RaC] ▶ Taking control...") - events["start_correction"] = True + print("\n[RaC] ▶ START pressed - taking control") + events["start_next_episode"] = True elif key == keyboard.Key.right: - print("\n[RaC] → End episode") + print("[RaC] → End episode") events["exit_early"] = True elif key == keyboard.Key.left: - print("\n[RaC] ← Re-record episode") + print("[RaC] ← Re-record episode") events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.esc: + print("[RaC] ESC - Stop recording, pushing to hub...") events["stop_recording"] = True events["exit_early"] = True except Exception as e: @@ -191,469 +203,630 @@ def init_keyboard_listener(events: dict): listener = keyboard.Listener(on_press=on_press) listener.start() + start_pedal_listener(events) - return listener + + return listener, events def start_pedal_listener(events: dict): - """Start foot pedal listener if available.""" + """Start foot pedal listener thread if evdev is available.""" import threading - + try: - from evdev import InputDevice, ecodes + from evdev import InputDevice, ecodes # noqa: F401 except ImportError: + logging.info("[Pedal] evdev not installed - pedal support disabled") return - + PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" KEY_LEFT = "KEY_A" KEY_RIGHT = "KEY_C" - + def pedal_reader(): try: dev = InputDevice(PEDAL_DEVICE) print(f"[Pedal] Connected: {dev.name}") - + print(f"[Pedal] Right=pause/next, Left=take control/start") + for ev in dev.read_loop(): if ev.type != ecodes.EV_KEY: continue - - from evdev import categorize - + + from evdev import categorize # noqa: F401 key = categorize(ev) code = key.keycode if isinstance(code, (list, tuple)): code = code[0] - + if key.keystate != 1: continue - + if events["in_reset"]: if code in [KEY_LEFT, KEY_RIGHT]: + print("\n[Pedal] Starting next episode...") events["start_next_episode"] = True else: if code == KEY_RIGHT: if events["correction_active"]: + print("\n[Pedal] → End episode") events["exit_early"] = True elif not events["policy_paused"]: - print("\n[Pedal] ⏸ PAUSED") + print("\n[Pedal] ⏸ PAUSED - Policy stopped") events["policy_paused"] = True elif code == KEY_LEFT: if events["policy_paused"] and not events["correction_active"]: - print("\n[Pedal] ▶ Taking control...") - events["start_correction"] = True - + print("\n[Pedal] ▶ START pressed - taking control") + events["start_next_episode"] = True + except FileNotFoundError: - pass + logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") except PermissionError: - logger.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}") - except Exception: - pass - + logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}") + except Exception as e: + logging.debug(f"[Pedal] Error: {e}") + thread = threading.Thread(target=pedal_reader, daemon=True) thread.start() -def get_actions_thread( +def make_identity_processors(): + """Create identity processors for RaC recording.""" + teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + robot_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[IdentityProcessorStep()], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + return teleop_proc, robot_proc, obs_proc + + +class SharedState: + """Thread-safe shared state for RTC inference thread.""" + def __init__(self): + self.obs: dict | None = None + self.action_queue: ActionQueue | None = None + + +def rtc_inference_thread( policy, - robot: RobotWrapper, - robot_observation_processor, - action_queue_holder: dict, - obs_holder: dict, + shared_state: SharedState, shutdown_event: Event, - cfg: RaCRTCConfig, policy_active: Event, - fps: int, + cfg: RaCRTCConfig, + hw_features: dict, + preprocessor, + postprocessor, ): - """Thread for async action generation via RTC. Only this thread reads from robot.""" - try: - logger.info("[GET_ACTIONS] Starting RTC action generation thread") - - latency_tracker = LatencyTracker() - time_per_chunk = 1.0 / fps - - hw_features = hw_to_dataset_features(robot.observation_features, "observation") - policy_device = policy.config.device - - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=None, - preprocessor_overrides={"device_processor": {"device": cfg.device}}, - ) - - get_actions_threshold = cfg.action_queue_size_to_get_new_actions - if not cfg.rtc.enabled: - get_actions_threshold = 0 - - while not shutdown_event.is_set(): - if not policy_active.is_set(): + """Background thread that generates action chunks using RTC. + + This thread: + - Waits for policy_active to be set + - Uses observation from shared_state.obs (set by main loop) + - Generates action chunks and puts them in shared_state.action_queue + """ + logger.info("[RTC] Inference thread started") + + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / cfg.dataset.fps + policy_device = policy.config.device + + get_actions_threshold = cfg.action_queue_size_to_get_new_actions + if not cfg.rtc.enabled: + get_actions_threshold = 0 + + inference_count = 0 + + while not shutdown_event.is_set(): + if not policy_active.is_set(): + time.sleep(0.01) + continue + + action_queue = shared_state.action_queue + if action_queue is None: + time.sleep(0.01) + continue + + if action_queue.qsize() <= get_actions_threshold: + obs = shared_state.obs + if obs is None: time.sleep(0.01) continue + + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() + + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + obs_with_policy_features = build_dataset_frame(hw_features, obs, prefix="observation") + + for name in obs_with_policy_features: + obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = ( + obs_with_policy_features[name].type(torch.float32) / 255 + ) + obs_with_policy_features[name] = ( + obs_with_policy_features[name].permute(2, 0, 1).contiguous() + ) + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) + obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) + + obs_with_policy_features["task"] = [cfg.dataset.single_task] + obs_with_policy_features["robot_type"] = "openarms_follower" + + preprocessed_obs = preprocessor(obs_with_policy_features) + + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + ) + + original_actions = actions.squeeze(0).clone() + postprocessed_actions = postprocessor(actions).squeeze(0) + + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + + action_queue.merge( + original_actions, postprocessed_actions, new_delay, action_index_before_inference + ) + + inference_count += 1 + if inference_count % 10 == 0: + logger.debug(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}") + else: + time.sleep(0.005) + + logger.info("[RTC] Inference thread shutting down") - action_queue = action_queue_holder["queue"] - if action_queue.qsize() <= get_actions_threshold: - current_time = time.perf_counter() - action_index_before_inference = action_queue.get_action_index() - prev_actions = action_queue.get_left_over() - inference_latency = latency_tracker.max() - inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 +@safe_stop_image_writer +def rac_rtc_rollout_loop( + robot, + teleop, + shared_state: SharedState, + policy_active: Event, + dataset: LeRobotDataset, + events: dict, + cfg: RaCRTCConfig, + action_keys: list[str], +) -> dict: + """RaC rollout loop with RTC for smooth policy execution.""" + fps = cfg.dataset.fps + single_task = cfg.dataset.single_task + control_time_s = cfg.dataset.episode_time_s + + frame_buffer = [] + stats = { + "total_frames": 0, + "autonomous_frames": 0, + "paused_frames": 0, + "correction_frames": 0, + } - obs = robot.get_observation() - obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} - obs_holder["obs"] = obs_filtered # Share with main loop for recording - obs_processed = robot_observation_processor(obs) + teleop.disable_torque() + was_paused = False + waiting_for_takeover = False + + # Interpolation state + prev_action: Tensor | None = None + interpolated_actions: list[Tensor] = [] + interp_idx = 0 + + if cfg.interpolation: + interp_factor = 2 + control_interval = 1.0 / (fps * interp_factor) + logger.info(f"[ROLLOUT] Interpolation ON: {fps}Hz -> {fps * interp_factor}Hz") + else: + interp_factor = 1 + control_interval = 1.0 / fps + logger.info(f"[ROLLOUT] Interpolation OFF: {fps}Hz") + + # Hz tracking + robot_send_count = 0 + policy_consume_count = 0 + last_hz_time = time.perf_counter() + last_record_time = 0.0 + + timestamp = 0 + start_t = time.perf_counter() - obs_with_policy_features = build_dataset_frame( - hw_features, obs_processed, prefix="observation" - ) + while timestamp < control_time_s: + loop_start = time.perf_counter() - for name in obs_with_policy_features: - obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) - if "image" in name: - obs_with_policy_features[name] = ( - obs_with_policy_features[name].type(torch.float32) / 255 - ) - obs_with_policy_features[name] = ( - obs_with_policy_features[name].permute(2, 0, 1).contiguous() - ) - obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) - obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) + if events["exit_early"]: + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False + break - obs_with_policy_features["task"] = [cfg.dataset.single_task] - obs_with_policy_features["robot_type"] = robot.name + # Get observation (always from main thread) + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} + + # Update shared observation for RTC thread + shared_state.obs = obs_filtered - preprocessed_obs = preprocessor(obs_with_policy_features) + # State transition: entering paused state + if events["policy_paused"] and not was_paused: + policy_active.clear() # Stop RTC inference + robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} + print("[RaC] Moving teleop to robot position (2s smooth transition)...") + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + print("[RaC] Teleop aligned. Press START to take control.") + events["start_next_episode"] = False + waiting_for_takeover = True + was_paused = True + # Reset interpolation state + prev_action = None + interpolated_actions = [] + interp_idx = 0 - actions = policy.predict_action_chunk( - preprocessed_obs, - inference_delay=inference_delay, - prev_chunk_left_over=prev_actions, - ) + # Wait for start button before enabling correction mode + if waiting_for_takeover and events["start_next_episode"]: + print("[RaC] Start pressed - enabling teleop control...") + teleop.disable_torque() + events["start_next_episode"] = False + events["correction_active"] = True + waiting_for_takeover = False - original_actions = actions.squeeze(0).clone() - postprocessed_actions = postprocessor(actions).squeeze(0) + obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) - new_latency = time.perf_counter() - current_time - new_delay = math.ceil(new_latency / time_per_chunk) - latency_tracker.add(new_latency) + if events["correction_active"]: + # Human controlling - record correction data + robot_action = teleop.get_action() + for key in robot_action: + if "gripper" in key: + robot_action[key] = -0.65 * robot_action[key] + robot.send_action(robot_action) + robot_send_count += 1 + stats["correction_frames"] += 1 + + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 + + elif waiting_for_takeover: + # Waiting for START - policy stopped, no recording + stats["paused_frames"] += 1 + + elif events["policy_paused"]: + # Paused - teleop tracks robot position + robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} + teleop.send_feedback(robot_pos) + stats["paused_frames"] += 1 + + else: + # Policy execution with RTC + policy_active.set() + action_queue = shared_state.action_queue + + # Get next action from queue (with interpolation) + if interp_idx >= len(interpolated_actions): + new_action = action_queue.get() if action_queue else None + if new_action is not None: + current_action = new_action.cpu() + policy_consume_count += 1 + + if cfg.interpolation and prev_action is not None: + mid = prev_action + 0.5 * (current_action - prev_action) + interpolated_actions = [mid, current_action] + else: + interpolated_actions = [current_action] + + prev_action = current_action + interp_idx = 0 + + if interp_idx < len(interpolated_actions): + action_to_send = interpolated_actions[interp_idx] + interp_idx += 1 + + action_dict = {} + for i, key in enumerate(action_keys): + if i < len(action_to_send): + action_dict[key] = action_to_send[i].item() + + robot.send_action(action_dict) + robot_send_count += 1 + stats["autonomous_frames"] += 1 + + # Record at dataset fps (not interpolated rate) + now = time.perf_counter() + if now - last_record_time >= (1.0 / fps): + last_record_time = now + action_frame = build_dataset_frame(dataset.features, action_dict, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 - action_queue.merge( - original_actions, postprocessed_actions, new_delay, action_index_before_inference - ) - else: - time.sleep(0.01) + # Print Hz stats every 5 seconds + now = time.perf_counter() + if now - last_hz_time >= 5.0: + elapsed = now - last_hz_time + actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0 + actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0 + mode = "CORRECTION" if events["correction_active"] else ("PAUSED" if events["policy_paused"] else "POLICY") + logger.info(f"[Hz] Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}, Mode: {mode}") + robot_send_count = 0 + policy_consume_count = 0 + last_hz_time = now - logger.info("[GET_ACTIONS] Shutting down") - except Exception as e: - logger.error(f"[GET_ACTIONS] Fatal exception: {e}") - logger.error(traceback.format_exc()) - shutdown_event.set() - sys.exit(1) + if cfg.display_data: + log_rerun_data(observation=obs_filtered, action=action_dict if 'action_dict' in dir() else {}) + + dt = time.perf_counter() - loop_start + sleep_time = control_interval - dt + if sleep_time > 0: + precise_sleep(sleep_time) + timestamp = time.perf_counter() - start_t + + policy_active.clear() + teleop.disable_torque() + + for frame in frame_buffer: + dataset.add_frame(frame) + + return stats + + +def reset_loop(robot, teleop, events: dict, fps: int): + """Reset period where human repositions environment.""" + print("\n" + "=" * 65) + print(" [RaC] RESET - Moving teleop to robot position...") + print("=" * 65) + + events["in_reset"] = True + events["start_next_episode"] = False + + obs = robot.get_observation() + robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) + + print(" Teleop aligned. Press any key/pedal to enable teleoperation") + while not events["start_next_episode"] and not events["stop_recording"]: + precise_sleep(0.05) + + if events["stop_recording"]: + return + + events["start_next_episode"] = False + teleop.disable_torque() + print(" Teleop enabled - move robot to starting position") + print(" Press any key/pedal to start next episode") + + while not events["start_next_episode"] and not events["stop_recording"]: + loop_start = time.perf_counter() + action = teleop.get_action() + for key in action: + if "gripper" in key: + action[key] = -0.65 * action[key] + robot.send_action(action) + dt = time.perf_counter() - loop_start + precise_sleep(1 / fps - dt) + + events["in_reset"] = False + events["start_next_episode"] = False + events["exit_early"] = False + events["policy_paused"] = False + events["correction_active"] = False @parser.wrap() -def main(cfg: RaCRTCConfig): - """Main RaC + RTC data collection.""" +def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: + """Main RaC data collection function with RTC.""" init_logging() + logging.info(pformat(cfg.__dict__)) - fps = cfg.dataset.fps - - print("=" * 65) - print(" RaC Data Collection with RTC - OpenArms") - print("=" * 65) - print(f" Policy: {cfg.policy.pretrained_path}") - print(f" Dataset: {cfg.dataset.repo_id}") - print(f" Task: {cfg.dataset.single_task}") - print(f" Policy Hz: {fps}") - print(f" Robot Hz: {fps * 2 if cfg.interpolation else fps}") - print(f" Interpolation: {cfg.interpolation}") - print(f" RTC Enabled: {cfg.rtc.enabled}") - print("=" * 65) - - events = { - "exit_early": False, - "rerecord_episode": False, - "stop_recording": False, - "policy_paused": False, - "correction_active": False, - "start_correction": False, - "in_reset": False, - "start_next_episode": False, - } - - shutdown_event = Event() - policy_active = Event() - - follower = make_robot_from_config(cfg.robot) - follower.connect() - time.sleep(1) - robot = RobotWrapper(follower) - logger.info("Robot connected") - - # Test robot communication - test_obs = robot.get_observation() - logger.info(f"Robot test read OK, got {len([k for k in test_obs if k.endswith('.pos')])} motor positions") + if cfg.display_data: + init_rerun(session_name="rac_rtc_collection_openarms") + robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) - teleop.connect() - teleop.disable_torque() - logger.info("Teleop connected") - - teleop_proc, robot_proc, obs_proc = make_default_processors() - action_features_hw = {k: v for k, v in follower.action_features.items() if k.endswith(".pos")} + + teleop_proc, robot_proc, obs_proc = make_identity_processors() dataset_features = combine_feature_dicts( aggregate_pipeline_dataset_features( pipeline=teleop_proc, - initial_features=create_initial_features(action=action_features_hw), + initial_features=create_initial_features(action=robot.action_features), use_videos=cfg.dataset.video, ), aggregate_pipeline_dataset_features( pipeline=obs_proc, - initial_features=create_initial_features(observation=follower.observation_features), + initial_features=create_initial_features(observation=robot.observation_features), use_videos=cfg.dataset.video, ), ) - dataset = LeRobotDataset.create( - repo_id=cfg.dataset.repo_id, - fps=fps, - root=cfg.dataset.root, - features=dataset_features, - robot_type=follower.name, - use_videos=cfg.dataset.video, - image_writer_processes=cfg.dataset.num_image_writer_processes, - image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera - * len(follower.cameras if hasattr(follower, "cameras") else []), - ) - dataset_lock = Lock() - - policy_class = get_policy_class(cfg.policy.type) - config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) - policy.config.rtc_config = cfg.rtc - policy.init_rtc_processor() - policy = policy.to(cfg.device) - policy.eval() - logger.info(f"Policy loaded: {policy.name}") - - action_queue_holder = {"queue": ActionQueue(cfg.rtc)} - obs_holder = {"obs": None} - - get_actions_t = Thread( - target=get_actions_thread, - args=(policy, robot, obs_proc, action_queue_holder, obs_holder, shutdown_event, cfg, policy_active, fps), - daemon=True, - name="GetActions", - ) - get_actions_t.start() - - listener = init_keyboard_listener(events) - - print("\n Controls:") - print(" SPACE/Right pedal - Pause policy") - print(" c/Left pedal - Take control") - print(" →/Right pedal - End episode (in correction mode)") - print(" ESC - Stop and push to hub") - print("=" * 65 + "\n") - - action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - - if cfg.interpolation: - interp_factor = 2 - robot_interval = 1.0 / (fps * interp_factor) - else: - interp_factor = 1 - robot_interval = 1.0 / fps + dataset = None + listener = None + shutdown_event = Event() + policy_active = Event() + shared_state = SharedState() + rtc_thread = None try: - recorded = 0 - while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"RaC episode {recorded + 1}", play_sounds=cfg.play_sounds) + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) + if hasattr(robot, "cameras") and robot.cameras: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + else: + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera + * len(robot.cameras if hasattr(robot, "cameras") else []), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, + ) - action_queue_holder["queue"] = ActionQueue(cfg.rtc) - action_queue = action_queue_holder["queue"] - events["policy_paused"] = False - events["correction_active"] = False - events["start_correction"] = False - events["exit_early"] = False + # Load policy + logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") + policy_class = get_policy_class(cfg.policy.type) + policy = policy_class.from_pretrained(cfg.policy.pretrained_path) + policy.config.rtc_config = cfg.rtc + policy.init_rtc_processor() + policy = policy.to(cfg.device) + policy.eval() + logger.info(f"Policy loaded: {policy.name}") - frame_buffer = [] - prev_action: Tensor | None = None - interpolated_actions: list[Tensor] = [] - interp_idx = 0 + # Setup preprocessor/postprocessor for RTC thread + hw_features = hw_to_dataset_features(robot.observation_features, "observation") + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) - robot_send_count = 0 - policy_consume_count = 0 - last_hz_print = time.perf_counter() + robot.connect() + teleop.connect() + listener, events = init_rac_keyboard_listener() - policy_active.set() - time.sleep(0.2) # Let RTC thread warm up - episode_start = time.perf_counter() + # Get action keys for the robot + action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - while (time.perf_counter() - episode_start) < cfg.dataset.episode_time_s: - loop_start = time.perf_counter() + # Start RTC inference thread + rtc_thread = Thread( + target=rtc_inference_thread, + args=( + policy, + shared_state, + shutdown_event, + policy_active, + cfg, + hw_features, + preprocessor, + postprocessor, + ), + daemon=True, + name="RTCInference", + ) + rtc_thread.start() + logger.info("Started RTC inference thread") - if events["exit_early"]: - break + print("\n" + "=" * 65) + print(" RaC (Recovery and Correction) Data Collection with RTC") + print("=" * 65) + print(f" Policy: {cfg.policy.pretrained_path}") + print(f" Task: {cfg.dataset.single_task}") + print(f" RTC Enabled: {cfg.rtc.enabled}") + print(f" Interpolation: {cfg.interpolation}") + print(f" Policy Hz: {cfg.dataset.fps}, Robot Hz: {cfg.dataset.fps * 2 if cfg.interpolation else cfg.dataset.fps}") + print() + print(" Controls:") + print(" SPACE - Pause policy (teleop tracks robot, no recording)") + print(" c - Take control (start correction, recording)") + print(" → - End episode (save)") + print(" ← - Re-record episode") + print(" ESC - Stop session and push to hub") + print("=" * 65 + "\n") - if events["start_correction"] and not events["correction_active"]: - policy_active.clear() - time.sleep(0.1) # Let RTC thread stop - print("[RaC] Moving teleop to robot position...") - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} - teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - teleop.disable_torque() - events["correction_active"] = True - events["start_correction"] = False - print("[RaC] Correction mode - you have control") + with VideoEncodingManager(dataset): + recorded = 0 + while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) + + # Create fresh action queue for this episode + shared_state.action_queue = ActionQueue(cfg.rtc) + shared_state.obs = None + + logger.info(f"\n{'='*40}") + logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}") + logger.info(f"{'='*40}") - if events["correction_active"]: - # Human controlling - RTC thread is paused, we read robot directly - obs = robot.get_observation() - obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} - obs_frame = build_dataset_frame(dataset_features, obs_filtered, prefix="observation") + stats = rac_rtc_rollout_loop( + robot=robot, + teleop=teleop, + shared_state=shared_state, + policy_active=policy_active, + dataset=dataset, + events=events, + cfg=cfg, + action_keys=action_keys, + ) - robot_action = teleop.get_action() - for key in robot_action: - if "gripper" in key: - robot_action[key] = -0.65 * robot_action[key] - robot.send_action(robot_action) + logging.info(f"Episode stats: {stats}") - action_frame = build_dataset_frame(dataset_features, robot_action, prefix="action") - frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} - frame_buffer.append(frame) + if events["rerecord_episode"]: + log_say("Re-recording", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - elif events["policy_paused"]: - pass + dataset.save_episode() + recorded += 1 - else: - # Policy running - RTC thread reads robot, we use shared observation - if interp_idx >= len(interpolated_actions): - new_action = action_queue.get() - if new_action is not None: - current_action = new_action.cpu() - policy_consume_count += 1 - - if cfg.interpolation and prev_action is not None: - mid = prev_action + 0.5 * (current_action - prev_action) - interpolated_actions = [mid, current_action] - else: - interpolated_actions = [current_action] - - prev_action = current_action - interp_idx = 0 - - if interp_idx < len(interpolated_actions): - action_to_send = interpolated_actions[interp_idx] - interp_idx += 1 - - action_dict = {} - for i, key in enumerate(action_keys): - if i < len(action_to_send): - action_dict[key] = action_to_send[i].item() - - action_processed = robot_proc((action_dict, None)) - robot.send_action(action_processed) - robot_send_count += 1 - - # Record using observation from RTC thread - if obs_holder["obs"] is not None: - obs_frame = build_dataset_frame(dataset_features, obs_holder["obs"], prefix="observation") - action_frame = build_dataset_frame(dataset_features, action_dict, prefix="action") - frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task} - frame_buffer.append(frame) - - now = time.perf_counter() - if now - last_hz_print >= 5.0: - elapsed = now - last_hz_print - actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0 - actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0 - logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}") - robot_send_count = 0 - policy_consume_count = 0 - last_hz_print = now - - dt = time.perf_counter() - loop_start - sleep_time = max(0, robot_interval - dt - 0.001) - if sleep_time > 0: - precise_sleep(sleep_time) - - policy_active.clear() - - if events["rerecord_episode"]: - log_say("Re-recording", play_sounds=cfg.play_sounds) - events["rerecord_episode"] = False - continue - - for frame in frame_buffer: - dataset.add_frame(frame) - - with dataset_lock: - if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0: - logger.info(f"Saving episode ({dataset.episode_buffer['size']} frames)") - dataset.save_episode() - - recorded += 1 - - if recorded < cfg.num_episodes and not events["stop_recording"]: - events["in_reset"] = True - events["start_next_episode"] = False - print("\n[RaC] RESET - Press any key/pedal to enable teleop") - - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} - teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - - while not events["start_next_episode"] and not events["stop_recording"]: - precise_sleep(0.05) - - if events["stop_recording"]: - break - - events["start_next_episode"] = False - teleop.disable_torque() - print("[RaC] Teleop enabled - move to start, then press key/pedal") - - while not events["start_next_episode"] and not events["stop_recording"]: - action = teleop.get_action() - for key in action: - if "gripper" in key: - action[key] = -0.65 * action[key] - robot.send_action(action) - precise_sleep(1 / fps) - - events["in_reset"] = False - events["start_next_episode"] = False - - log_say("Recording complete", play_sounds=cfg.play_sounds, blocking=True) - - except KeyboardInterrupt: - logger.info("Interrupted by user") + if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: + reset_loop( + robot=robot, + teleop=teleop, + events=events, + fps=cfg.dataset.fps, + ) finally: + log_say("Stop recording", cfg.play_sounds, blocking=True) + shutdown_event.set() policy_active.clear() + + if rtc_thread and rtc_thread.is_alive(): + logger.info("Waiting for RTC thread to finish...") + rtc_thread.join(timeout=2.0) - if get_actions_t.is_alive(): - get_actions_t.join(timeout=5.0) + if dataset: + dataset.finalize() - follower.disconnect() - teleop.disconnect() + if robot.is_connected: + robot.disconnect() + if teleop.is_connected: + teleop.disconnect() - if listener: + if not is_headless() and listener: listener.stop() - dataset.finalize() if cfg.dataset.push_to_hub: - logger.info("Pushing to hub...") - dataset.push_to_hub(private=cfg.dataset.private) + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) - logger.info("Done") + return dataset + + +def main(): + from lerobot.utils.import_utils import register_third_party_plugins + register_third_party_plugins() + rac_rtc_collect() if __name__ == "__main__": main() -