diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py index 2c1b43cce..e20d2c01b 100644 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -54,9 +54,11 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.latency_tracker import LatencyTracker +from lerobot.policies.utils import make_robot_action from lerobot.processor import ( IdentityProcessorStep, PolicyAction, @@ -72,12 +74,12 @@ from lerobot.processor.converters import ( transition_to_robot_action, ) from lerobot.processor.rename_processor import rename_stats -from lerobot.robots import RobotConfig, make_robot_from_config +from lerobot.robots import Robot, RobotConfig, make_robot_from_config from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401 -from lerobot.teleoperators import TeleoperatorConfig, make_teleoperator_from_config +from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401 from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.control_utils import is_headless +from lerobot.utils.control_utils import is_headless, predict_action from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say from lerobot.utils.visualization_utils import init_rerun, log_rerun_data @@ -119,10 +121,10 @@ class RaCRTCConfig: policy: PreTrainedConfig | None = None rtc: RTCConfig = field(default_factory=lambda: RTCConfig( - enabled=True, - execution_horizon=10, - max_guidance_weight=10.0, - prefix_attention_schedule=RTCAttentionSchedule.EXP, + enabled=True, + execution_horizon=20, + max_guidance_weight=5.0, + prefix_attention_schedule=RTCAttentionSchedule.LINEAR, )) interpolation: bool = True @@ -290,138 +292,38 @@ def make_identity_processors(): return teleop_proc, robot_proc, obs_proc -class SharedState: - """Thread-safe shared state for RTC inference thread.""" - def __init__(self): - self.obs: dict | None = None - self.action_queue: ActionQueue | None = None - - -def rtc_inference_thread( - policy, - shared_state: SharedState, - shutdown_event: Event, - policy_active: Event, - cfg: RaCRTCConfig, - hw_features: dict, - preprocessor, - postprocessor, -): - """Background thread that generates action chunks using RTC. - - This thread: - - Waits for policy_active to be set - - Uses observation from shared_state.obs (set by main loop) - - Generates action chunks and puts them in shared_state.action_queue - """ - logger.info("[RTC] Inference thread started (waiting for policy_active signal)") - logger.info("[RTC] Thread is IDLE - will not do anything until main loop activates policy") - - latency_tracker = LatencyTracker() - time_per_chunk = 1.0 / cfg.dataset.fps - policy_device = policy.config.device - - get_actions_threshold = cfg.action_queue_size_to_get_new_actions - if not cfg.rtc.enabled: - get_actions_threshold = 0 - - inference_count = 0 - was_active = False - - while not shutdown_event.is_set(): - if not policy_active.is_set(): - if was_active: - logger.info("[RTC] Policy deactivated, thread going idle") - was_active = False - time.sleep(0.01) - continue - - if not was_active: - logger.info("[RTC] Policy activated! Starting inference loop") - was_active = True - - action_queue = shared_state.action_queue - if action_queue is None: - time.sleep(0.01) - continue - - if action_queue.qsize() <= get_actions_threshold: - obs = shared_state.obs - if obs is None: - time.sleep(0.01) - continue - - current_time = time.perf_counter() - action_index_before_inference = action_queue.get_action_index() - prev_actions = action_queue.get_left_over() - - inference_latency = latency_tracker.max() - inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 - - obs_with_policy_features = build_dataset_frame(hw_features, obs, prefix="observation") - - for name in obs_with_policy_features: - obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) - if "image" in name: - obs_with_policy_features[name] = ( - obs_with_policy_features[name].type(torch.float32) / 255 - ) - obs_with_policy_features[name] = ( - obs_with_policy_features[name].permute(2, 0, 1).contiguous() - ) - obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) - obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) - - obs_with_policy_features["task"] = [cfg.dataset.single_task] - obs_with_policy_features["robot_type"] = "openarms_follower" - - preprocessed_obs = preprocessor(obs_with_policy_features) - - actions = policy.predict_action_chunk( - preprocessed_obs, - inference_delay=inference_delay, - prev_chunk_left_over=prev_actions, - ) - - original_actions = actions.squeeze(0).clone() - postprocessed_actions = postprocessor(actions).squeeze(0) - - new_latency = time.perf_counter() - current_time - new_delay = math.ceil(new_latency / time_per_chunk) - latency_tracker.add(new_latency) - - action_queue.merge( - original_actions, postprocessed_actions, new_delay, action_index_before_inference - ) - - inference_count += 1 - if inference_count % 10 == 0: - logger.debug(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}") - else: - time.sleep(0.005) - - logger.info("[RTC] Inference thread shutting down") - - @safe_stop_image_writer def rac_rtc_rollout_loop( - robot, - teleop, - shared_state: SharedState, - policy_active: Event, + robot: Robot, + teleop: Teleoperator, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], dataset: LeRobotDataset, events: dict, - cfg: RaCRTCConfig, - action_keys: list[str], + fps: int, + control_time_s: float, + single_task: str, + display_data: bool = True, + use_rtc: bool = True, + rtc_config: RTCConfig | None = None, + interpolation: bool = False, + device: str = "cuda", ) -> dict: - """RaC rollout loop with RTC for smooth policy execution.""" - logger.info("[ROLLOUT] Starting rollout loop...") - - fps = cfg.dataset.fps - single_task = cfg.dataset.single_task - control_time_s = cfg.dataset.episode_time_s + """ + RaC rollout loop with optional RTC for smooth policy execution. + Matches the original rac_data_collection_openarms.py structure exactly, + but uses RTC action queue for smoother motion when use_rtc=True. + """ + # Reset policy and processors - EXACTLY like original + policy.reset() + preprocessor.reset() + postprocessor.reset() + + device = get_safe_torch_device(device) frame_buffer = [] + stats = { "total_frames": 0, "autonomous_frames": 0, @@ -429,33 +331,28 @@ def rac_rtc_rollout_loop( "correction_frames": 0, } + # Start with teleop torque disabled - EXACTLY like original teleop.disable_torque() was_paused = False waiting_for_takeover = False - - # Interpolation state + + # RTC state (only used when use_rtc=True) + action_queue = None + latency_tracker = None + time_per_chunk = 1.0 / fps prev_action: Tensor | None = None interpolated_actions: list[Tensor] = [] interp_idx = 0 + action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - if cfg.interpolation: - interp_factor = 2 - control_interval = 1.0 / (fps * interp_factor) - logger.info(f"[ROLLOUT] Interpolation ON: {fps}Hz -> {fps * interp_factor}Hz") - else: - interp_factor = 1 - control_interval = 1.0 / fps - logger.info(f"[ROLLOUT] Interpolation OFF: {fps}Hz") - - # Hz tracking - robot_send_count = 0 - policy_consume_count = 0 - last_hz_time = time.perf_counter() - last_record_time = 0.0 - + if use_rtc and rtc_config: + action_queue = ActionQueue(rtc_config) + latency_tracker = LatencyTracker() + get_actions_threshold = 30 if rtc_config.enabled else 0 + timestamp = 0 start_t = time.perf_counter() - first_iteration = True + robot_action = {} # Initialize for log_rerun_data while timestamp < control_time_s: loop_start = time.perf_counter() @@ -466,21 +363,10 @@ def rac_rtc_rollout_loop( events["correction_active"] = False break - # Get observation (always from main thread - only place robot is read) - if first_iteration: - logger.info("[ROLLOUT] First iteration - reading observation from robot...") - obs = robot.get_observation() - if first_iteration: - logger.info("[ROLLOUT] First observation read OK") - first_iteration = False - obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} - - # Update shared observation for RTC thread - shared_state.obs = obs_filtered - - # State transition: entering paused state + # Detect transition to paused state - EXACTLY like original if events["policy_paused"] and not was_paused: - policy_active.clear() # Stop RTC inference + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} print("[RaC] Moving teleop to robot position (2s smooth transition)...") teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) @@ -493,7 +379,7 @@ def rac_rtc_rollout_loop( interpolated_actions = [] interp_idx = 0 - # Wait for start button before enabling correction mode + # Wait for start button - EXACTLY like original if waiting_for_takeover and events["start_next_episode"]: print("[RaC] Start pressed - enabling teleop control...") teleop.disable_torque() @@ -501,98 +387,130 @@ def rac_rtc_rollout_loop( events["correction_active"] = True waiting_for_takeover = False + # Get observation - EXACTLY like original + obs = robot.get_observation() + obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) if events["correction_active"]: - # Human controlling - record correction data + # Human controlling - EXACTLY like original robot_action = teleop.get_action() for key in robot_action: if "gripper" in key: robot_action[key] = -0.65 * robot_action[key] robot.send_action(robot_action) - robot_send_count += 1 stats["correction_frames"] += 1 + # Record this frame action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": single_task} frame_buffer.append(frame) stats["total_frames"] += 1 elif waiting_for_takeover: - # Waiting for START - policy stopped, no recording + # Waiting for START - EXACTLY like original (no action sent to robot!) stats["paused_frames"] += 1 elif events["policy_paused"]: - # Paused - teleop tracks robot position + # Paused - teleop tracks robot - EXACTLY like original robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} teleop.send_feedback(robot_pos) stats["paused_frames"] += 1 else: - # Policy execution with RTC - policy_active.set() - action_queue = shared_state.action_queue - - # Get next action from queue (with interpolation) - if interp_idx >= len(interpolated_actions): - new_action = action_queue.get() if action_queue else None - if new_action is not None: - current_action = new_action.cpu() - policy_consume_count += 1 + # Policy execution - use RTC if enabled, otherwise original predict_action + if use_rtc and action_queue is not None: + # RTC path: check if we need to generate more actions + if action_queue.qsize() <= get_actions_threshold: + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() - if cfg.interpolation and prev_action is not None: + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + # Run inference - using predict_action for consistency with original + action_values = predict_action( + observation=obs_frame, + policy=policy, + device=device, + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, + task=single_task, + robot_type=robot.robot_type, + ) + + new_latency = time.perf_counter() - current_time + latency_tracker.add(new_latency) + + # Get action from queue + queue_action = action_queue.get() + if queue_action is not None: + current_action = queue_action.cpu() if isinstance(queue_action, Tensor) else queue_action + + # Handle interpolation + if interpolation and prev_action is not None and isinstance(current_action, Tensor): mid = prev_action + 0.5 * (current_action - prev_action) interpolated_actions = [mid, current_action] else: interpolated_actions = [current_action] - prev_action = current_action + if isinstance(current_action, Tensor): + prev_action = current_action interp_idx = 0 - - if interp_idx < len(interpolated_actions): - action_to_send = interpolated_actions[interp_idx] - interp_idx += 1 - action_dict = {} - for i, key in enumerate(action_keys): - if i < len(action_to_send): - action_dict[key] = action_to_send[i].item() - - robot.send_action(action_dict) - robot_send_count += 1 - stats["autonomous_frames"] += 1 - - # Record at dataset fps (not interpolated rate) - now = time.perf_counter() - if now - last_record_time >= (1.0 / fps): - last_record_time = now - action_frame = build_dataset_frame(dataset.features, action_dict, prefix=ACTION) + # Send interpolated action + if interp_idx < len(interpolated_actions): + action_to_send = interpolated_actions[interp_idx] + interp_idx += 1 + + if isinstance(action_to_send, Tensor): + robot_action = {} + for i, key in enumerate(action_keys): + if i < len(action_to_send): + robot_action[key] = action_to_send[i].item() + else: + robot_action = action_to_send + + robot.send_action(robot_action) + stats["autonomous_frames"] += 1 + + # Record this frame + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": single_task} frame_buffer.append(frame) stats["total_frames"] += 1 + else: + # Original path - EXACTLY like original rac_data_collection_openarms.py + action_values = predict_action( + observation=obs_frame, + policy=policy, + device=device, + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, + task=single_task, + robot_type=robot.robot_type, + ) + robot_action: RobotAction = make_robot_action(action_values, dataset.features) + robot.send_action(robot_action) + stats["autonomous_frames"] += 1 + + # Record this frame + action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) + frame = {**obs_frame, **action_frame, "task": single_task} + frame_buffer.append(frame) + stats["total_frames"] += 1 - # Print Hz stats every 5 seconds - now = time.perf_counter() - if now - last_hz_time >= 5.0: - elapsed = now - last_hz_time - actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0 - actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0 - mode = "CORRECTION" if events["correction_active"] else ("PAUSED" if events["policy_paused"] else "POLICY") - logger.info(f"[Hz] Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}, Mode: {mode}") - robot_send_count = 0 - policy_consume_count = 0 - last_hz_time = now - - if cfg.display_data: - log_rerun_data(observation=obs_filtered, action=action_dict if 'action_dict' in dir() else {}) + if display_data: + log_rerun_data(observation=obs_filtered, action=robot_action) dt = time.perf_counter() - loop_start - sleep_time = control_interval - dt - if sleep_time > 0: - precise_sleep(sleep_time) + precise_sleep(1 / fps - dt) timestamp = time.perf_counter() - start_t - policy_active.clear() + # Ensure teleoperator torque is disabled at end - EXACTLY like original teleop.disable_torque() for frame in frame_buffer: @@ -601,8 +519,12 @@ def rac_rtc_rollout_loop( return stats -def reset_loop(robot, teleop, events: dict, fps: int): - """Reset period where human repositions environment.""" +def reset_loop( + robot: Robot, + teleop: Teleoperator, + events: dict, + fps: int, +): print("\n" + "=" * 65) print(" [RaC] RESET - Moving teleop to robot position...") print("=" * 65) @@ -672,10 +594,6 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: dataset = None listener = None - shutdown_event = Event() - policy_active = Event() - shared_state = SharedState() - rtc_thread = None try: if cfg.resume: @@ -703,73 +621,47 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) - # Load policy - logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") - policy_class = get_policy_class(cfg.policy.type) - policy = policy_class.from_pretrained(cfg.policy.pretrained_path) - policy.config.rtc_config = cfg.rtc - policy.init_rtc_processor() - policy = policy.to(cfg.device) - policy.eval() - logger.info(f"Policy loaded: {policy.name}") + # Load policy - same as original + policy = None + preprocessor = None + postprocessor = None + + if cfg.policy: + logger.info(f"Loading policy from: {cfg.policy.pretrained_path}") + policy_class = get_policy_class(cfg.policy.type) + policy = policy_class.from_pretrained(cfg.policy.pretrained_path) + + # Setup RTC if enabled + if cfg.rtc.enabled: + policy.config.rtc_config = cfg.rtc + policy.init_rtc_processor() + + policy = policy.to(cfg.device) + policy.eval() + logger.info(f"Policy loaded: {policy.name}") - # Setup preprocessor/postprocessor for RTC thread - hw_features = hw_to_dataset_features(robot.observation_features, "observation") - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), - preprocessor_overrides={ - "device_processor": {"device": cfg.device}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, - }, - ) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) robot.connect() - logger.info("Robot connected, waiting for CAN bus to stabilize...") - time.sleep(1.0) # Let CAN bus stabilize - - # Test read to verify robot communication is working - logger.info("Testing robot communication...") - test_obs = robot.get_observation() - logger.info(f"Robot test read OK, got {len(test_obs)} observation keys") - time.sleep(0.5) - teleop.connect() listener, events = init_rac_keyboard_listener() - # Get action keys for the robot - action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - logger.info(f"Action keys: {action_keys}") - - # Start RTC inference thread (it will be idle until policy_active is set) - logger.info("Starting RTC inference thread (will be idle until episode starts)...") - rtc_thread = Thread( - target=rtc_inference_thread, - args=( - policy, - shared_state, - shutdown_event, - policy_active, - cfg, - hw_features, - preprocessor, - postprocessor, - ), - daemon=True, - name="RTCInference", - ) - rtc_thread.start() - logger.info("Started RTC inference thread") - print("\n" + "=" * 65) print(" RaC (Recovery and Correction) Data Collection with RTC") print("=" * 65) - print(f" Policy: {cfg.policy.pretrained_path}") + print(f" Policy: {cfg.policy.pretrained_path if cfg.policy else 'None'}") print(f" Task: {cfg.dataset.single_task}") print(f" RTC Enabled: {cfg.rtc.enabled}") print(f" Interpolation: {cfg.interpolation}") - print(f" Policy Hz: {cfg.dataset.fps}, Robot Hz: {cfg.dataset.fps * 2 if cfg.interpolation else cfg.dataset.fps}") + print(f" FPS: {cfg.dataset.fps}") print() print(" Controls:") print(" SPACE - Pause policy (teleop tracks robot, no recording)") @@ -784,10 +676,6 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) - # Create fresh action queue for this episode - shared_state.action_queue = ActionQueue(cfg.rtc) - shared_state.obs = None - logger.info(f"\n{'='*40}") logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}") logger.info(f"{'='*40}") @@ -795,12 +683,19 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: stats = rac_rtc_rollout_loop( robot=robot, teleop=teleop, - shared_state=shared_state, - policy_active=policy_active, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, events=events, - cfg=cfg, - action_keys=action_keys, + fps=cfg.dataset.fps, + control_time_s=cfg.dataset.episode_time_s, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + use_rtc=cfg.rtc.enabled, + rtc_config=cfg.rtc, + interpolation=cfg.interpolation, + device=cfg.device, ) logging.info(f"Episode stats: {stats}") @@ -825,13 +720,6 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: finally: log_say("Stop recording", cfg.play_sounds, blocking=True) - - shutdown_event.set() - policy_active.clear() - - if rtc_thread and rtc_thread.is_alive(): - logger.info("Waiting for RTC thread to finish...") - rtc_thread.join(timeout=2.0) if dataset: dataset.finalize()