From 2d1fb0f50830c60187822711e22076f628618de3 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 9 Jan 2026 16:41:59 +0100 Subject: [PATCH] refactor --- .../rac/rac_data_collection_openarms_rtc.py | 512 +++++++++++------- 1 file changed, 309 insertions(+), 203 deletions(-) diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py index e20d2c01b..4f54f6968 100644 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -13,13 +13,6 @@ The workflow: 4. Press → to end episode (save and continue to next) 5. Reset, then do next rollout -Keyboard Controls: - SPACE - Pause policy (teleop mirrors robot, no recording) - c - Take control (teleop free, recording correction) - → - End episode (save and continue to next) - ← - Re-record episode - ESC - Stop recording and push dataset to hub - Usage: python examples/rac/rac_data_collection_openarms_rtc.py \ --robot.port_right=can0 \ @@ -37,7 +30,7 @@ import time from dataclasses import dataclass, field from pathlib import Path from pprint import pformat -from threading import Event, Thread +from threading import Event, Lock, Thread from typing import Any import torch @@ -88,6 +81,10 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# ============================================================================ +# Configuration +# ============================================================================ + @dataclass class RaCRTCDatasetConfig: repo_id: str = "lerobot/rac_openarms_rtc" @@ -148,6 +145,46 @@ class RaCRTCConfig: return ["policy"] +# ============================================================================ +# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py) +# ============================================================================ + +class RobotWrapper: + """Thread-safe wrapper for robot operations.""" + + def __init__(self, robot: Robot): + self.robot = robot + self.lock = Lock() + + def get_observation(self) -> dict[str, Tensor]: + with self.lock: + return self.robot.get_observation() + + def send_action(self, action: dict) -> None: + with self.lock: + self.robot.send_action(action) + + @property + def observation_features(self) -> dict: + return self.robot.observation_features + + @property + def action_features(self) -> dict: + return self.robot.action_features + + @property + def name(self) -> str: + return self.robot.name + + @property + def robot_type(self) -> str: + return self.robot.robot_type + + +# ============================================================================ +# Keyboard/Pedal Listeners +# ============================================================================ + def init_rac_keyboard_listener(): """Initialize keyboard listener with RaC-specific controls.""" events = { @@ -229,7 +266,6 @@ def start_pedal_listener(events: dict): try: dev = InputDevice(PEDAL_DEVICE) print(f"[Pedal] Connected: {dev.name}") - print(f"[Pedal] Right=pause/next, Left=take control/start") for ev in dev.read_loop(): if ev.type != ecodes.EV_KEY: @@ -246,25 +282,21 @@ def start_pedal_listener(events: dict): if events["in_reset"]: if code in [KEY_LEFT, KEY_RIGHT]: - print("\n[Pedal] Starting next episode...") events["start_next_episode"] = True else: if code == KEY_RIGHT: if events["correction_active"]: - print("\n[Pedal] → End episode") events["exit_early"] = True elif not events["policy_paused"]: - print("\n[Pedal] ⏸ PAUSED - Policy stopped") events["policy_paused"] = True elif code == KEY_LEFT: if events["policy_paused"] and not events["correction_active"]: - print("\n[Pedal] ▶ START pressed - taking control") events["start_next_episode"] = True except FileNotFoundError: logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") except PermissionError: - logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}") + logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}") except Exception as e: logging.debug(f"[Pedal] Error: {e}") @@ -292,38 +324,139 @@ def make_identity_processors(): return teleop_proc, robot_proc, obs_proc +# ============================================================================ +# RTC Inference Thread (from evaluate_with_rtc.py) +# ============================================================================ + +def rtc_inference_thread( + policy, + obs_holder: dict, # {"obs": filtered_obs, "features": observation_features} - set by main loop + hw_features: dict, + preprocessor, + postprocessor, + queue_holder: dict, # {"queue": ActionQueue} - mutable so we can update per episode + shutdown_event: Event, + policy_active: Event, + cfg: RaCRTCConfig, +): + """Background thread that generates action chunks using RTC. + + IMPORTANT: This thread does NOT access the robot directly! + It reads observations from obs_holder which is updated by the main loop. + This avoids race conditions on the CAN bus. + """ + logger.info("[RTC] Inference thread started (reads obs from main loop, no direct robot access)") + + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / cfg.dataset.fps + policy_device = policy.config.device + + get_actions_threshold = cfg.action_queue_size_to_get_new_actions + if not cfg.rtc.enabled: + get_actions_threshold = 0 + + while not shutdown_event.is_set(): + if not policy_active.is_set(): + time.sleep(0.01) + continue + + action_queue = queue_holder["queue"] + if action_queue is None: + time.sleep(0.01) + continue + + # Get observation from shared holder (set by main loop) + obs_filtered = obs_holder.get("obs") + if obs_filtered is None: + time.sleep(0.01) + continue + + if action_queue.qsize() <= get_actions_threshold: + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() + + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + # Build observation for policy (using obs from main loop) + obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation") + + # Convert to tensors (like evaluate_with_rtc.py) + for name in obs_with_policy_features: + obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = ( + obs_with_policy_features[name].type(torch.float32) / 255 + ) + obs_with_policy_features[name] = ( + obs_with_policy_features[name].permute(2, 0, 1).contiguous() + ) + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) + obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) + + obs_with_policy_features["task"] = [cfg.dataset.single_task] + obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower") + + # Preprocess and run inference + preprocessed_obs = preprocessor(obs_with_policy_features) + + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + ) + + original_actions = actions.squeeze(0).clone() + postprocessed_actions = postprocessor(actions).squeeze(0) + + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + + # Put actions in queue! + action_queue.merge( + original_actions, postprocessed_actions, new_delay, action_index_before_inference + ) + + logger.debug(f"[RTC] Generated chunk, latency={new_latency:.2f}s, queue={action_queue.qsize()}") + else: + time.sleep(0.01) + + logger.info("[RTC] Inference thread shutting down") + + +# ============================================================================ +# Main Rollout Loop +# ============================================================================ + @safe_stop_image_writer def rac_rtc_rollout_loop( - robot: Robot, + robot: RobotWrapper, teleop: Teleoperator, policy: PreTrainedPolicy, - preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], + preprocessor, + postprocessor, dataset: LeRobotDataset, events: dict, - fps: int, - control_time_s: float, - single_task: str, - display_data: bool = True, - use_rtc: bool = True, - rtc_config: RTCConfig | None = None, - interpolation: bool = False, - device: str = "cuda", + cfg: RaCRTCConfig, + queue_holder: dict, + obs_holder: dict, # Main loop writes obs here for RTC thread to read + policy_active: Event, + hw_features: dict, ) -> dict: - """ - RaC rollout loop with optional RTC for smooth policy execution. + """RaC rollout loop with RTC for smooth policy execution.""" + fps = cfg.dataset.fps + single_task = cfg.dataset.single_task + control_time_s = cfg.dataset.episode_time_s + device = get_safe_torch_device(cfg.device) - Matches the original rac_data_collection_openarms.py structure exactly, - but uses RTC action queue for smoother motion when use_rtc=True. - """ - # Reset policy and processors - EXACTLY like original + # Reset policy state policy.reset() preprocessor.reset() postprocessor.reset() - - device = get_safe_torch_device(device) + frame_buffer = [] - stats = { "total_frames": 0, "autonomous_frames": 0, @@ -331,28 +464,26 @@ def rac_rtc_rollout_loop( "correction_frames": 0, } - # Start with teleop torque disabled - EXACTLY like original teleop.disable_torque() was_paused = False waiting_for_takeover = False - - # RTC state (only used when use_rtc=True) - action_queue = None - latency_tracker = None - time_per_chunk = 1.0 / fps + + # Action keys for converting tensor to dict + action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] + + # Interpolation state prev_action: Tensor | None = None interpolated_actions: list[Tensor] = [] interp_idx = 0 - action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - if use_rtc and rtc_config: - action_queue = ActionQueue(rtc_config) - latency_tracker = LatencyTracker() - get_actions_threshold = 30 if rtc_config.enabled else 0 - + if cfg.interpolation: + control_interval = 1.0 / (fps * 2) # 2x rate + else: + control_interval = 1.0 / fps + + robot_action = {} timestamp = 0 start_t = time.perf_counter() - robot_action = {} # Initialize for log_rerun_data while timestamp < control_time_s: loop_start = time.perf_counter() @@ -363,37 +494,41 @@ def rac_rtc_rollout_loop( events["correction_active"] = False break - # Detect transition to paused state - EXACTLY like original + # State transition: entering paused state if events["policy_paused"] and not was_paused: + policy_active.clear() # Stop RTC inference obs = robot.get_observation() obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} - print("[RaC] Moving teleop to robot position (2s smooth transition)...") + print("[RaC] Moving teleop to robot position...") teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - print("[RaC] Teleop aligned. Press START to take control.") + print("[RaC] Teleop aligned. Press 'c' to take control.") events["start_next_episode"] = False waiting_for_takeover = True was_paused = True - # Reset interpolation state + # Reset interpolation prev_action = None interpolated_actions = [] interp_idx = 0 - # Wait for start button - EXACTLY like original + # Wait for takeover if waiting_for_takeover and events["start_next_episode"]: - print("[RaC] Start pressed - enabling teleop control...") + print("[RaC] Taking control...") teleop.disable_torque() events["start_next_episode"] = False events["correction_active"] = True waiting_for_takeover = False - # Get observation - EXACTLY like original + # Get observation (ONLY the main loop reads from robot!) obs = robot.get_observation() obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) + + # Share observation with RTC thread (thread reads, main loop writes) + obs_holder["obs"] = obs_filtered if events["correction_active"]: - # Human controlling - EXACTLY like original + # Human controlling robot_action = teleop.get_action() for key in robot_action: if "gripper" in key: @@ -401,116 +536,67 @@ def rac_rtc_rollout_loop( robot.send_action(robot_action) stats["correction_frames"] += 1 - # Record this frame action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": single_task} frame_buffer.append(frame) stats["total_frames"] += 1 elif waiting_for_takeover: - # Waiting for START - EXACTLY like original (no action sent to robot!) stats["paused_frames"] += 1 elif events["policy_paused"]: - # Paused - teleop tracks robot - EXACTLY like original robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} teleop.send_feedback(robot_pos) stats["paused_frames"] += 1 else: - # Policy execution - use RTC if enabled, otherwise original predict_action - if use_rtc and action_queue is not None: - # RTC path: check if we need to generate more actions - if action_queue.qsize() <= get_actions_threshold: - current_time = time.perf_counter() - action_index_before_inference = action_queue.get_action_index() - prev_actions = action_queue.get_left_over() + # Policy execution with RTC + policy_active.set() + action_queue = queue_holder["queue"] + + # Get action from queue (with interpolation) + if interp_idx >= len(interpolated_actions): + new_action = action_queue.get() if action_queue else None + if new_action is not None: + current_action = new_action.cpu() - inference_latency = latency_tracker.max() - inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 - - # Run inference - using predict_action for consistency with original - action_values = predict_action( - observation=obs_frame, - policy=policy, - device=device, - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=single_task, - robot_type=robot.robot_type, - ) - - new_latency = time.perf_counter() - current_time - latency_tracker.add(new_latency) - - # Get action from queue - queue_action = action_queue.get() - if queue_action is not None: - current_action = queue_action.cpu() if isinstance(queue_action, Tensor) else queue_action - - # Handle interpolation - if interpolation and prev_action is not None and isinstance(current_action, Tensor): + if cfg.interpolation and prev_action is not None: mid = prev_action + 0.5 * (current_action - prev_action) interpolated_actions = [mid, current_action] else: interpolated_actions = [current_action] - if isinstance(current_action, Tensor): - prev_action = current_action + prev_action = current_action interp_idx = 0 + + if interp_idx < len(interpolated_actions): + action_to_send = interpolated_actions[interp_idx] + interp_idx += 1 + + robot_action = {} + for i, key in enumerate(action_keys): + if i < len(action_to_send): + robot_action[key] = action_to_send[i].item() - # Send interpolated action - if interp_idx < len(interpolated_actions): - action_to_send = interpolated_actions[interp_idx] - interp_idx += 1 - - if isinstance(action_to_send, Tensor): - robot_action = {} - for i, key in enumerate(action_keys): - if i < len(action_to_send): - robot_action[key] = action_to_send[i].item() - else: - robot_action = action_to_send - - robot.send_action(robot_action) - stats["autonomous_frames"] += 1 - - # Record this frame - action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) - frame = {**obs_frame, **action_frame, "task": single_task} - frame_buffer.append(frame) - stats["total_frames"] += 1 - else: - # Original path - EXACTLY like original rac_data_collection_openarms.py - action_values = predict_action( - observation=obs_frame, - policy=policy, - device=device, - preprocessor=preprocessor, - postprocessor=postprocessor, - use_amp=policy.config.use_amp, - task=single_task, - robot_type=robot.robot_type, - ) - robot_action: RobotAction = make_robot_action(action_values, dataset.features) robot.send_action(robot_action) stats["autonomous_frames"] += 1 - # Record this frame + # Record at original fps action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": single_task} frame_buffer.append(frame) stats["total_frames"] += 1 - if display_data: + if cfg.display_data: log_rerun_data(observation=obs_filtered, action=robot_action) dt = time.perf_counter() - loop_start - precise_sleep(1 / fps - dt) + sleep_time = control_interval - dt + if sleep_time > 0: + precise_sleep(sleep_time) timestamp = time.perf_counter() - start_t - # Ensure teleoperator torque is disabled at end - EXACTLY like original + policy_active.clear() teleop.disable_torque() for frame in frame_buffer: @@ -519,14 +605,10 @@ def rac_rtc_rollout_loop( return stats -def reset_loop( - robot: Robot, - teleop: Teleoperator, - events: dict, - fps: int, -): +def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int): + """Reset period where human repositions environment.""" print("\n" + "=" * 65) - print(" [RaC] RESET - Moving teleop to robot position...") + print(" [RaC] RESET") print("=" * 65) events["in_reset"] = True @@ -536,7 +618,7 @@ def reset_loop( robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) - print(" Teleop aligned. Press any key/pedal to enable teleoperation") + print(" Press any key/pedal to enable teleoperation") while not events["start_next_episode"] and not events["stop_recording"]: precise_sleep(0.05) @@ -545,8 +627,7 @@ def reset_loop( events["start_next_episode"] = False teleop.disable_torque() - print(" Teleop enabled - move robot to starting position") - print(" Press any key/pedal to start next episode") + print(" Teleop enabled - press any key/pedal to start episode") while not events["start_next_episode"] and not events["stop_recording"]: loop_start = time.perf_counter() @@ -565,6 +646,10 @@ def reset_loop( events["correction_active"] = False +# ============================================================================ +# Main Entry Point +# ============================================================================ + @parser.wrap() def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: """Main RaC data collection function with RTC.""" @@ -574,7 +659,7 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: if cfg.display_data: init_rerun(session_name="rac_rtc_collection_openarms") - robot = make_robot_from_config(cfg.robot) + robot_raw = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) teleop_proc, robot_proc, obs_proc = make_identity_processors() @@ -582,18 +667,21 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: dataset_features = combine_feature_dicts( aggregate_pipeline_dataset_features( pipeline=teleop_proc, - initial_features=create_initial_features(action=robot.action_features), + initial_features=create_initial_features(action=robot_raw.action_features), use_videos=cfg.dataset.video, ), aggregate_pipeline_dataset_features( pipeline=obs_proc, - initial_features=create_initial_features(observation=robot.observation_features), + initial_features=create_initial_features(observation=robot_raw.observation_features), use_videos=cfg.dataset.video, ), ) dataset = None listener = None + shutdown_event = Event() + policy_active = Event() + rtc_thread = None try: if cfg.resume: @@ -602,73 +690,92 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) - if hasattr(robot, "cameras") and robot.cameras: + if hasattr(robot_raw, "cameras") and robot_raw.cameras: dataset.start_image_writer( num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras), ) else: dataset = LeRobotDataset.create( cfg.dataset.repo_id, cfg.dataset.fps, root=cfg.dataset.root, - robot_type=robot.name, + robot_type=robot_raw.name, features=dataset_features, use_videos=cfg.dataset.video, image_writer_processes=cfg.dataset.num_image_writer_processes, image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera - * len(robot.cameras if hasattr(robot, "cameras") else []), + * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []), batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) - # Load policy - same as original - policy = None - preprocessor = None - postprocessor = None + # Load policy + logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") + policy_class = get_policy_class(cfg.policy.type) + policy = policy_class.from_pretrained(cfg.policy.pretrained_path) + policy.config.rtc_config = cfg.rtc + policy.init_rtc_processor() + policy = policy.to(cfg.device) + policy.eval() + logger.info(f"Policy loaded: {policy.name}") + + # Setup preprocessor/postprocessor + hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation") + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) + + # Connect robot and wrap for thread safety + robot_raw.connect() + robot = RobotWrapper(robot_raw) - if cfg.policy: - logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") - policy_class = get_policy_class(cfg.policy.type) - policy = policy_class.from_pretrained(cfg.policy.pretrained_path) - - # Setup RTC if enabled - if cfg.rtc.enabled: - policy.config.rtc_config = cfg.rtc - policy.init_rtc_processor() - - policy = policy.to(cfg.device) - policy.eval() - logger.info(f"Policy loaded: {policy.name}") - - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) - - robot.connect() teleop.connect() listener, events = init_rac_keyboard_listener() + # Shared state holders (main loop writes, RTC thread reads) + queue_holder = {"queue": ActionQueue(cfg.rtc)} + obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs + + # Start RTC inference thread + # NOTE: Thread does NOT access robot directly - reads from obs_holder + rtc_thread = Thread( + target=rtc_inference_thread, + args=( + policy, + obs_holder, # Thread reads obs from here (set by main loop) + hw_features, + preprocessor, + postprocessor, + queue_holder, + shutdown_event, + policy_active, + cfg, + ), + daemon=True, + name="RTCInference", + ) + rtc_thread.start() + logger.info("Started RTC inference thread") + print("\n" + "=" * 65) - print(" RaC (Recovery and Correction) Data Collection with RTC") + print(" RaC Data Collection with RTC") print("=" * 65) - print(f" Policy: {cfg.policy.pretrained_path if cfg.policy else 'None'}") + print(f" Policy: {cfg.policy.pretrained_path}") print(f" Task: {cfg.dataset.single_task}") - print(f" RTC Enabled: {cfg.rtc.enabled}") - print(f" Interpolation: {cfg.interpolation}") print(f" FPS: {cfg.dataset.fps}") + print(f" Interpolation: {cfg.interpolation}") print() print(" Controls:") - print(" SPACE - Pause policy (teleop tracks robot, no recording)") - print(" c - Take control (start correction, recording)") - print(" → - End episode (save)") - print(" ← - Re-record episode") - print(" ESC - Stop session and push to hub") + print(" SPACE - Pause policy") + print(" c - Take control") + print(" → - End episode") + print(" ESC - Stop and push to hub") print("=" * 65 + "\n") with VideoEncodingManager(dataset): @@ -676,9 +783,10 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) - logger.info(f"\n{'='*40}") + # Fresh action queue per episode (update holder so thread sees it) + queue_holder["queue"] = ActionQueue(cfg.rtc) + logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}") - logger.info(f"{'='*40}") stats = rac_rtc_rollout_loop( robot=robot, @@ -688,14 +796,11 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: postprocessor=postprocessor, dataset=dataset, events=events, - fps=cfg.dataset.fps, - control_time_s=cfg.dataset.episode_time_s, - single_task=cfg.dataset.single_task, - display_data=cfg.display_data, - use_rtc=cfg.rtc.enabled, - rtc_config=cfg.rtc, - interpolation=cfg.interpolation, - device=cfg.device, + cfg=cfg, + queue_holder=queue_holder, + obs_holder=obs_holder, + policy_active=policy_active, + hw_features=hw_features, ) logging.info(f"Episode stats: {stats}") @@ -711,21 +816,22 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: recorded += 1 if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: - reset_loop( - robot=robot, - teleop=teleop, - events=events, - fps=cfg.dataset.fps, - ) + reset_loop(robot, teleop, events, cfg.dataset.fps) finally: log_say("Stop recording", cfg.play_sounds, blocking=True) + + shutdown_event.set() + policy_active.clear() + + if rtc_thread and rtc_thread.is_alive(): + rtc_thread.join(timeout=2.0) if dataset: dataset.finalize() - if robot.is_connected: - robot.disconnect() + if robot_raw.is_connected: + robot_raw.disconnect() if teleop.is_connected: teleop.disconnect()