This commit is contained in:
Pepijn
2026-01-09 16:41:59 +01:00
parent b1a55b0666
commit 2d1fb0f508
+309 -203
View File
@@ -13,13 +13,6 @@ The workflow:
4. Press → to end episode (save and continue to next) 4. Press → to end episode (save and continue to next)
5. Reset, then do next rollout 5. Reset, then do next rollout
Keyboard Controls:
SPACE - Pause policy (teleop mirrors robot, no recording)
c - Take control (teleop free, recording correction)
→ - End episode (save and continue to next)
← - Re-record episode
ESC - Stop recording and push dataset to hub
Usage: Usage:
python examples/rac/rac_data_collection_openarms_rtc.py \ python examples/rac/rac_data_collection_openarms_rtc.py \
--robot.port_right=can0 \ --robot.port_right=can0 \
@@ -37,7 +30,7 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from threading import Event, Thread from threading import Event, Lock, Thread
from typing import Any from typing import Any
import torch import torch
@@ -88,6 +81,10 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ============================================================================
# Configuration
# ============================================================================
@dataclass @dataclass
class RaCRTCDatasetConfig: class RaCRTCDatasetConfig:
repo_id: str = "lerobot/rac_openarms_rtc" repo_id: str = "lerobot/rac_openarms_rtc"
@@ -148,6 +145,46 @@ class RaCRTCConfig:
return ["policy"] return ["policy"]
# ============================================================================
# Thread-Safe Robot Wrapper (from evaluate_with_rtc.py)
# ============================================================================
class RobotWrapper:
"""Thread-safe wrapper for robot operations."""
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: dict) -> None:
with self.lock:
self.robot.send_action(action)
@property
def observation_features(self) -> dict:
return self.robot.observation_features
@property
def action_features(self) -> dict:
return self.robot.action_features
@property
def name(self) -> str:
return self.robot.name
@property
def robot_type(self) -> str:
return self.robot.robot_type
# ============================================================================
# Keyboard/Pedal Listeners
# ============================================================================
def init_rac_keyboard_listener(): def init_rac_keyboard_listener():
"""Initialize keyboard listener with RaC-specific controls.""" """Initialize keyboard listener with RaC-specific controls."""
events = { events = {
@@ -229,7 +266,6 @@ def start_pedal_listener(events: dict):
try: try:
dev = InputDevice(PEDAL_DEVICE) dev = InputDevice(PEDAL_DEVICE)
print(f"[Pedal] Connected: {dev.name}") print(f"[Pedal] Connected: {dev.name}")
print(f"[Pedal] Right=pause/next, Left=take control/start")
for ev in dev.read_loop(): for ev in dev.read_loop():
if ev.type != ecodes.EV_KEY: if ev.type != ecodes.EV_KEY:
@@ -246,25 +282,21 @@ def start_pedal_listener(events: dict):
if events["in_reset"]: if events["in_reset"]:
if code in [KEY_LEFT, KEY_RIGHT]: if code in [KEY_LEFT, KEY_RIGHT]:
print("\n[Pedal] Starting next episode...")
events["start_next_episode"] = True events["start_next_episode"] = True
else: else:
if code == KEY_RIGHT: if code == KEY_RIGHT:
if events["correction_active"]: if events["correction_active"]:
print("\n[Pedal] → End episode")
events["exit_early"] = True events["exit_early"] = True
elif not events["policy_paused"]: elif not events["policy_paused"]:
print("\n[Pedal] ⏸ PAUSED - Policy stopped")
events["policy_paused"] = True events["policy_paused"] = True
elif code == KEY_LEFT: elif code == KEY_LEFT:
if events["policy_paused"] and not events["correction_active"]: if events["policy_paused"] and not events["correction_active"]:
print("\n[Pedal] ▶ START pressed - taking control")
events["start_next_episode"] = True events["start_next_episode"] = True
except FileNotFoundError: except FileNotFoundError:
logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}") logging.info(f"[Pedal] Device not found: {PEDAL_DEVICE}")
except PermissionError: except PermissionError:
logging.warning(f"[Pedal] Permission denied. Run: sudo setfacl -m u:$USER:rw {PEDAL_DEVICE}") logging.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}")
except Exception as e: except Exception as e:
logging.debug(f"[Pedal] Error: {e}") logging.debug(f"[Pedal] Error: {e}")
@@ -292,38 +324,139 @@ def make_identity_processors():
return teleop_proc, robot_proc, obs_proc return teleop_proc, robot_proc, obs_proc
# ============================================================================
# RTC Inference Thread (from evaluate_with_rtc.py)
# ============================================================================
def rtc_inference_thread(
policy,
obs_holder: dict, # {"obs": filtered_obs, "features": observation_features} - set by main loop
hw_features: dict,
preprocessor,
postprocessor,
queue_holder: dict, # {"queue": ActionQueue} - mutable so we can update per episode
shutdown_event: Event,
policy_active: Event,
cfg: RaCRTCConfig,
):
"""Background thread that generates action chunks using RTC.
IMPORTANT: This thread does NOT access the robot directly!
It reads observations from obs_holder which is updated by the main loop.
This avoids race conditions on the CAN bus.
"""
logger.info("[RTC] Inference thread started (reads obs from main loop, no direct robot access)")
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.dataset.fps
policy_device = policy.config.device
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if not policy_active.is_set():
time.sleep(0.01)
continue
action_queue = queue_holder["queue"]
if action_queue is None:
time.sleep(0.01)
continue
# Get observation from shared holder (set by main loop)
obs_filtered = obs_holder.get("obs")
if obs_filtered is None:
time.sleep(0.01)
continue
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
# Build observation for policy (using obs from main loop)
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
# Convert to tensors (like evaluate_with_rtc.py)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.dataset.single_task]
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
# Preprocess and run inference
preprocessed_obs = preprocessor(obs_with_policy_features)
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
# Put actions in queue!
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
logger.debug(f"[RTC] Generated chunk, latency={new_latency:.2f}s, queue={action_queue.qsize()}")
else:
time.sleep(0.01)
logger.info("[RTC] Inference thread shutting down")
# ============================================================================
# Main Rollout Loop
# ============================================================================
@safe_stop_image_writer @safe_stop_image_writer
def rac_rtc_rollout_loop( def rac_rtc_rollout_loop(
robot: Robot, robot: RobotWrapper,
teleop: Teleoperator, teleop: Teleoperator,
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor,
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], postprocessor,
dataset: LeRobotDataset, dataset: LeRobotDataset,
events: dict, events: dict,
fps: int, cfg: RaCRTCConfig,
control_time_s: float, queue_holder: dict,
single_task: str, obs_holder: dict, # Main loop writes obs here for RTC thread to read
display_data: bool = True, policy_active: Event,
use_rtc: bool = True, hw_features: dict,
rtc_config: RTCConfig | None = None,
interpolation: bool = False,
device: str = "cuda",
) -> dict: ) -> dict:
""" """RaC rollout loop with RTC for smooth policy execution."""
RaC rollout loop with optional RTC for smooth policy execution. fps = cfg.dataset.fps
single_task = cfg.dataset.single_task
control_time_s = cfg.dataset.episode_time_s
device = get_safe_torch_device(cfg.device)
Matches the original rac_data_collection_openarms.py structure exactly, # Reset policy state
but uses RTC action queue for smoother motion when use_rtc=True.
"""
# Reset policy and processors - EXACTLY like original
policy.reset() policy.reset()
preprocessor.reset() preprocessor.reset()
postprocessor.reset() postprocessor.reset()
device = get_safe_torch_device(device)
frame_buffer = [] frame_buffer = []
stats = { stats = {
"total_frames": 0, "total_frames": 0,
"autonomous_frames": 0, "autonomous_frames": 0,
@@ -331,28 +464,26 @@ def rac_rtc_rollout_loop(
"correction_frames": 0, "correction_frames": 0,
} }
# Start with teleop torque disabled - EXACTLY like original
teleop.disable_torque() teleop.disable_torque()
was_paused = False was_paused = False
waiting_for_takeover = False waiting_for_takeover = False
# RTC state (only used when use_rtc=True) # Action keys for converting tensor to dict
action_queue = None action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
latency_tracker = None
time_per_chunk = 1.0 / fps # Interpolation state
prev_action: Tensor | None = None prev_action: Tensor | None = None
interpolated_actions: list[Tensor] = [] interpolated_actions: list[Tensor] = []
interp_idx = 0 interp_idx = 0
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
if use_rtc and rtc_config: if cfg.interpolation:
action_queue = ActionQueue(rtc_config) control_interval = 1.0 / (fps * 2) # 2x rate
latency_tracker = LatencyTracker() else:
get_actions_threshold = 30 if rtc_config.enabled else 0 control_interval = 1.0 / fps
robot_action = {}
timestamp = 0 timestamp = 0
start_t = time.perf_counter() start_t = time.perf_counter()
robot_action = {} # Initialize for log_rerun_data
while timestamp < control_time_s: while timestamp < control_time_s:
loop_start = time.perf_counter() loop_start = time.perf_counter()
@@ -363,37 +494,41 @@ def rac_rtc_rollout_loop(
events["correction_active"] = False events["correction_active"] = False
break break
# Detect transition to paused state - EXACTLY like original # State transition: entering paused state
if events["policy_paused"] and not was_paused: if events["policy_paused"] and not was_paused:
policy_active.clear() # Stop RTC inference
obs = robot.get_observation() obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
print("[RaC] Moving teleop to robot position (2s smooth transition)...") print("[RaC] Moving teleop to robot position...")
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
print("[RaC] Teleop aligned. Press START to take control.") print("[RaC] Teleop aligned. Press 'c' to take control.")
events["start_next_episode"] = False events["start_next_episode"] = False
waiting_for_takeover = True waiting_for_takeover = True
was_paused = True was_paused = True
# Reset interpolation state # Reset interpolation
prev_action = None prev_action = None
interpolated_actions = [] interpolated_actions = []
interp_idx = 0 interp_idx = 0
# Wait for start button - EXACTLY like original # Wait for takeover
if waiting_for_takeover and events["start_next_episode"]: if waiting_for_takeover and events["start_next_episode"]:
print("[RaC] Start pressed - enabling teleop control...") print("[RaC] Taking control...")
teleop.disable_torque() teleop.disable_torque()
events["start_next_episode"] = False events["start_next_episode"] = False
events["correction_active"] = True events["correction_active"] = True
waiting_for_takeover = False waiting_for_takeover = False
# Get observation - EXACTLY like original # Get observation (ONLY the main loop reads from robot!)
obs = robot.get_observation() obs = robot.get_observation()
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features} obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR) obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
# Share observation with RTC thread (thread reads, main loop writes)
obs_holder["obs"] = obs_filtered
if events["correction_active"]: if events["correction_active"]:
# Human controlling - EXACTLY like original # Human controlling
robot_action = teleop.get_action() robot_action = teleop.get_action()
for key in robot_action: for key in robot_action:
if "gripper" in key: if "gripper" in key:
@@ -401,116 +536,67 @@ def rac_rtc_rollout_loop(
robot.send_action(robot_action) robot.send_action(robot_action)
stats["correction_frames"] += 1 stats["correction_frames"] += 1
# Record this frame
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task} frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame) frame_buffer.append(frame)
stats["total_frames"] += 1 stats["total_frames"] += 1
elif waiting_for_takeover: elif waiting_for_takeover:
# Waiting for START - EXACTLY like original (no action sent to robot!)
stats["paused_frames"] += 1 stats["paused_frames"] += 1
elif events["policy_paused"]: elif events["policy_paused"]:
# Paused - teleop tracks robot - EXACTLY like original
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")} robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
teleop.send_feedback(robot_pos) teleop.send_feedback(robot_pos)
stats["paused_frames"] += 1 stats["paused_frames"] += 1
else: else:
# Policy execution - use RTC if enabled, otherwise original predict_action # Policy execution with RTC
if use_rtc and action_queue is not None: policy_active.set()
# RTC path: check if we need to generate more actions action_queue = queue_holder["queue"]
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter() # Get action from queue (with interpolation)
action_index_before_inference = action_queue.get_action_index() if interp_idx >= len(interpolated_actions):
prev_actions = action_queue.get_left_over() new_action = action_queue.get() if action_queue else None
if new_action is not None:
current_action = new_action.cpu()
inference_latency = latency_tracker.max() if cfg.interpolation and prev_action is not None:
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
# Run inference - using predict_action for consistency with original
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
new_latency = time.perf_counter() - current_time
latency_tracker.add(new_latency)
# Get action from queue
queue_action = action_queue.get()
if queue_action is not None:
current_action = queue_action.cpu() if isinstance(queue_action, Tensor) else queue_action
# Handle interpolation
if interpolation and prev_action is not None and isinstance(current_action, Tensor):
mid = prev_action + 0.5 * (current_action - prev_action) mid = prev_action + 0.5 * (current_action - prev_action)
interpolated_actions = [mid, current_action] interpolated_actions = [mid, current_action]
else: else:
interpolated_actions = [current_action] interpolated_actions = [current_action]
if isinstance(current_action, Tensor): prev_action = current_action
prev_action = current_action
interp_idx = 0 interp_idx = 0
if interp_idx < len(interpolated_actions):
action_to_send = interpolated_actions[interp_idx]
interp_idx += 1
robot_action = {}
for i, key in enumerate(action_keys):
if i < len(action_to_send):
robot_action[key] = action_to_send[i].item()
# Send interpolated action
if interp_idx < len(interpolated_actions):
action_to_send = interpolated_actions[interp_idx]
interp_idx += 1
if isinstance(action_to_send, Tensor):
robot_action = {}
for i, key in enumerate(action_keys):
if i < len(action_to_send):
robot_action[key] = action_to_send[i].item()
else:
robot_action = action_to_send
robot.send_action(robot_action)
stats["autonomous_frames"] += 1
# Record this frame
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame)
stats["total_frames"] += 1
else:
# Original path - EXACTLY like original rac_data_collection_openarms.py
action_values = predict_action(
observation=obs_frame,
policy=policy,
device=device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
robot.send_action(robot_action) robot.send_action(robot_action)
stats["autonomous_frames"] += 1 stats["autonomous_frames"] += 1
# Record this frame # Record at original fps
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": single_task} frame = {**obs_frame, **action_frame, "task": single_task}
frame_buffer.append(frame) frame_buffer.append(frame)
stats["total_frames"] += 1 stats["total_frames"] += 1
if display_data: if cfg.display_data:
log_rerun_data(observation=obs_filtered, action=robot_action) log_rerun_data(observation=obs_filtered, action=robot_action)
dt = time.perf_counter() - loop_start dt = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt) sleep_time = control_interval - dt
if sleep_time > 0:
precise_sleep(sleep_time)
timestamp = time.perf_counter() - start_t timestamp = time.perf_counter() - start_t
# Ensure teleoperator torque is disabled at end - EXACTLY like original policy_active.clear()
teleop.disable_torque() teleop.disable_torque()
for frame in frame_buffer: for frame in frame_buffer:
@@ -519,14 +605,10 @@ def rac_rtc_rollout_loop(
return stats return stats
def reset_loop( def reset_loop(robot: RobotWrapper, teleop: Teleoperator, events: dict, fps: int):
robot: Robot, """Reset period where human repositions environment."""
teleop: Teleoperator,
events: dict,
fps: int,
):
print("\n" + "=" * 65) print("\n" + "=" * 65)
print(" [RaC] RESET - Moving teleop to robot position...") print(" [RaC] RESET")
print("=" * 65) print("=" * 65)
events["in_reset"] = True events["in_reset"] = True
@@ -536,7 +618,7 @@ def reset_loop(
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50) teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
print(" Teleop aligned. Press any key/pedal to enable teleoperation") print(" Press any key/pedal to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]: while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05) precise_sleep(0.05)
@@ -545,8 +627,7 @@ def reset_loop(
events["start_next_episode"] = False events["start_next_episode"] = False
teleop.disable_torque() teleop.disable_torque()
print(" Teleop enabled - move robot to starting position") print(" Teleop enabled - press any key/pedal to start episode")
print(" Press any key/pedal to start next episode")
while not events["start_next_episode"] and not events["stop_recording"]: while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter() loop_start = time.perf_counter()
@@ -565,6 +646,10 @@ def reset_loop(
events["correction_active"] = False events["correction_active"] = False
# ============================================================================
# Main Entry Point
# ============================================================================
@parser.wrap() @parser.wrap()
def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset: def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
"""Main RaC data collection function with RTC.""" """Main RaC data collection function with RTC."""
@@ -574,7 +659,7 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
if cfg.display_data: if cfg.display_data:
init_rerun(session_name="rac_rtc_collection_openarms") init_rerun(session_name="rac_rtc_collection_openarms")
robot = make_robot_from_config(cfg.robot) robot_raw = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop) teleop = make_teleoperator_from_config(cfg.teleop)
teleop_proc, robot_proc, obs_proc = make_identity_processors() teleop_proc, robot_proc, obs_proc = make_identity_processors()
@@ -582,18 +667,21 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
dataset_features = combine_feature_dicts( dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features( aggregate_pipeline_dataset_features(
pipeline=teleop_proc, pipeline=teleop_proc,
initial_features=create_initial_features(action=robot.action_features), initial_features=create_initial_features(action=robot_raw.action_features),
use_videos=cfg.dataset.video, use_videos=cfg.dataset.video,
), ),
aggregate_pipeline_dataset_features( aggregate_pipeline_dataset_features(
pipeline=obs_proc, pipeline=obs_proc,
initial_features=create_initial_features(observation=robot.observation_features), initial_features=create_initial_features(observation=robot_raw.observation_features),
use_videos=cfg.dataset.video, use_videos=cfg.dataset.video,
), ),
) )
dataset = None dataset = None
listener = None listener = None
shutdown_event = Event()
policy_active = Event()
rtc_thread = None
try: try:
if cfg.resume: if cfg.resume:
@@ -602,73 +690,92 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
root=cfg.dataset.root, root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
) )
if hasattr(robot, "cameras") and robot.cameras: if hasattr(robot_raw, "cameras") and robot_raw.cameras:
dataset.start_image_writer( dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes, num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
) )
else: else:
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
cfg.dataset.repo_id, cfg.dataset.repo_id,
cfg.dataset.fps, cfg.dataset.fps,
root=cfg.dataset.root, root=cfg.dataset.root,
robot_type=robot.name, robot_type=robot_raw.name,
features=dataset_features, features=dataset_features,
use_videos=cfg.dataset.video, use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes, image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []), * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size, batch_encoding_size=cfg.dataset.video_encoding_batch_size,
) )
# Load policy - same as original # Load policy
policy = None logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
preprocessor = None policy_class = get_policy_class(cfg.policy.type)
postprocessor = None policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
policy.config.rtc_config = cfg.rtc
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
logger.info(f"Policy loaded: {policy.name}")
# Setup preprocessor/postprocessor
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Connect robot and wrap for thread safety
robot_raw.connect()
robot = RobotWrapper(robot_raw)
if cfg.policy:
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
policy_class = get_policy_class(cfg.policy.type)
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
# Setup RTC if enabled
if cfg.rtc.enabled:
policy.config.rtc_config = cfg.rtc
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
logger.info(f"Policy loaded: {policy.name}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
teleop.connect() teleop.connect()
listener, events = init_rac_keyboard_listener() listener, events = init_rac_keyboard_listener()
# Shared state holders (main loop writes, RTC thread reads)
queue_holder = {"queue": ActionQueue(cfg.rtc)}
obs_holder = {"obs": None, "robot_type": robot.robot_type} # Main loop updates obs
# Start RTC inference thread
# NOTE: Thread does NOT access robot directly - reads from obs_holder
rtc_thread = Thread(
target=rtc_inference_thread,
args=(
policy,
obs_holder, # Thread reads obs from here (set by main loop)
hw_features,
preprocessor,
postprocessor,
queue_holder,
shutdown_event,
policy_active,
cfg,
),
daemon=True,
name="RTCInference",
)
rtc_thread.start()
logger.info("Started RTC inference thread")
print("\n" + "=" * 65) print("\n" + "=" * 65)
print(" RaC (Recovery and Correction) Data Collection with RTC") print(" RaC Data Collection with RTC")
print("=" * 65) print("=" * 65)
print(f" Policy: {cfg.policy.pretrained_path if cfg.policy else 'None'}") print(f" Policy: {cfg.policy.pretrained_path}")
print(f" Task: {cfg.dataset.single_task}") print(f" Task: {cfg.dataset.single_task}")
print(f" RTC Enabled: {cfg.rtc.enabled}")
print(f" Interpolation: {cfg.interpolation}")
print(f" FPS: {cfg.dataset.fps}") print(f" FPS: {cfg.dataset.fps}")
print(f" Interpolation: {cfg.interpolation}")
print() print()
print(" Controls:") print(" Controls:")
print(" SPACE - Pause policy (teleop tracks robot, no recording)") print(" SPACE - Pause policy")
print(" c - Take control (start correction, recording)") print(" c - Take control")
print(" → - End episode (save)") print(" → - End episode")
print(" - Re-record episode") print(" ESC - Stop and push to hub")
print(" ESC - Stop session and push to hub")
print("=" * 65 + "\n") print("=" * 65 + "\n")
with VideoEncodingManager(dataset): with VideoEncodingManager(dataset):
@@ -676,9 +783,10 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds) log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
logger.info(f"\n{'='*40}") # Fresh action queue per episode (update holder so thread sees it)
queue_holder["queue"] = ActionQueue(cfg.rtc)
logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}") logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}")
logger.info(f"{'='*40}")
stats = rac_rtc_rollout_loop( stats = rac_rtc_rollout_loop(
robot=robot, robot=robot,
@@ -688,14 +796,11 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
postprocessor=postprocessor, postprocessor=postprocessor,
dataset=dataset, dataset=dataset,
events=events, events=events,
fps=cfg.dataset.fps, cfg=cfg,
control_time_s=cfg.dataset.episode_time_s, queue_holder=queue_holder,
single_task=cfg.dataset.single_task, obs_holder=obs_holder,
display_data=cfg.display_data, policy_active=policy_active,
use_rtc=cfg.rtc.enabled, hw_features=hw_features,
rtc_config=cfg.rtc,
interpolation=cfg.interpolation,
device=cfg.device,
) )
logging.info(f"Episode stats: {stats}") logging.info(f"Episode stats: {stats}")
@@ -711,21 +816,22 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
recorded += 1 recorded += 1
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]: if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
reset_loop( reset_loop(robot, teleop, events, cfg.dataset.fps)
robot=robot,
teleop=teleop,
events=events,
fps=cfg.dataset.fps,
)
finally: finally:
log_say("Stop recording", cfg.play_sounds, blocking=True) log_say("Stop recording", cfg.play_sounds, blocking=True)
shutdown_event.set()
policy_active.clear()
if rtc_thread and rtc_thread.is_alive():
rtc_thread.join(timeout=2.0)
if dataset: if dataset:
dataset.finalize() dataset.finalize()
if robot.is_connected: if robot_raw.is_connected:
robot.disconnect() robot_raw.disconnect()
if teleop.is_connected: if teleop.is_connected:
teleop.disconnect() teleop.disconnect()