mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
By default, don't use RTC
This commit is contained in:
@@ -54,9 +54,11 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -72,12 +74,12 @@ from lerobot.processor.converters import (
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
|
||||
from lerobot.teleoperators import TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
@@ -119,10 +121,10 @@ class RaCRTCConfig:
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=10.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
))
|
||||
|
||||
interpolation: bool = True
|
||||
@@ -290,138 +292,38 @@ def make_identity_processors():
|
||||
return teleop_proc, robot_proc, obs_proc
|
||||
|
||||
|
||||
class SharedState:
    """Mutable holder for data exchanged with the RTC inference thread.

    The class performs no locking of its own: both fields are plain
    attributes that are replaced wholesale by assignment. Callers that need
    stronger guarantees than a single attribute read/write must add their
    own synchronization.
    """

    def __init__(self):
        # Most recent observation published for inference; None until one is set.
        self.obs: dict | None = None
        # Action queue to fill with predicted chunks; None until one is assigned.
        self.action_queue: ActionQueue | None = None
|
||||
|
||||
|
||||
def rtc_inference_thread(
    policy,
    shared_state: SharedState,
    shutdown_event: Event,
    policy_active: Event,
    cfg: RaCRTCConfig,
    hw_features: dict,
    preprocessor,
    postprocessor,
):
    """Background loop that refills the shared action queue with RTC chunks.

    The loop stays idle until ``policy_active`` is set. While active, whenever
    the queue in ``shared_state.action_queue`` holds no more than the refill
    threshold, it converts ``shared_state.obs`` to tensors, runs
    ``policy.predict_action_chunk`` and merges the result into the queue.
    The loop exits once ``shutdown_event`` is set.

    Args:
        policy: Object exposing ``config.device`` and
            ``predict_action_chunk(obs, inference_delay=..., prev_chunk_left_over=...)``.
        shared_state: Holder whose ``obs`` and ``action_queue`` attributes are
            read on every iteration.
        shutdown_event: Set by the owner to terminate the loop.
        policy_active: Inference only runs while this event is set.
        cfg: Configuration; reads ``dataset.fps``, ``dataset.single_task``,
            ``action_queue_size_to_get_new_actions`` and ``rtc.enabled``.
        hw_features: Feature description forwarded to ``build_dataset_frame``.
        preprocessor: Callable applied to the observation batch before inference.
        postprocessor: Callable applied to the predicted action chunk.
    """
    logger.info("[RTC] Inference thread started (waiting for policy_active signal)")
    logger.info("[RTC] Thread is IDLE - will not do anything until main loop activates policy")

    idle_poll_s = 0.01  # wait while inactive or while inputs are missing
    busy_poll_s = 0.005  # wait while the queue is still above the threshold

    latency_tracker = LatencyTracker()
    time_per_chunk = 1.0 / cfg.dataset.fps
    policy_device = policy.config.device

    # With RTC disabled, only refill once the queue is completely drained.
    get_actions_threshold = cfg.action_queue_size_to_get_new_actions if cfg.rtc.enabled else 0

    inference_count = 0
    was_active = False

    while not shutdown_event.is_set():
        if not policy_active.is_set():
            if was_active:
                logger.info("[RTC] Policy deactivated, thread going idle")
                was_active = False
            # Waiting on the shutdown event (instead of a blind sleep) lets a
            # shutdown request end the idle period immediately.
            shutdown_event.wait(idle_poll_s)
            continue

        if not was_active:
            logger.info("[RTC] Policy activated! Starting inference loop")
            was_active = True

        action_queue = shared_state.action_queue
        if action_queue is None:
            shutdown_event.wait(idle_poll_s)
            continue

        if action_queue.qsize() > get_actions_threshold:
            shutdown_event.wait(busy_poll_s)
            continue

        obs = shared_state.obs
        if obs is None:
            shutdown_event.wait(idle_poll_s)
            continue

        current_time = time.perf_counter()
        action_index_before_inference = action_queue.get_action_index()
        prev_actions = action_queue.get_left_over()

        # Express the tracked latency as a whole number of control periods
        # (0 when the tracker has nothing to report yet).
        inference_latency = latency_tracker.max()
        inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0

        obs_with_policy_features = build_dataset_frame(hw_features, obs, prefix="observation")

        for name in obs_with_policy_features:
            tensor = torch.from_numpy(obs_with_policy_features[name])
            if "image" in name:
                # Scale by 1/255 and move the last axis first
                # (presumably HWC uint8 -> CHW float; confirm against camera output).
                tensor = (tensor.type(torch.float32) / 255).permute(2, 0, 1).contiguous()
            # Add a leading batch axis and move to the policy's device.
            obs_with_policy_features[name] = tensor.unsqueeze(0).to(policy_device)

        obs_with_policy_features["task"] = [cfg.dataset.single_task]
        obs_with_policy_features["robot_type"] = "openarms_follower"

        preprocessed_obs = preprocessor(obs_with_policy_features)

        actions = policy.predict_action_chunk(
            preprocessed_obs,
            inference_delay=inference_delay,
            prev_chunk_left_over=prev_actions,
        )

        # Keep the raw chunk as well as the postprocessed one; both are merged.
        original_actions = actions.squeeze(0).clone()
        postprocessed_actions = postprocessor(actions).squeeze(0)

        new_latency = time.perf_counter() - current_time
        new_delay = math.ceil(new_latency / time_per_chunk)
        latency_tracker.add(new_latency)

        action_queue.merge(
            original_actions, postprocessed_actions, new_delay, action_index_before_inference
        )

        inference_count += 1
        if inference_count % 10 == 0:
            logger.debug(
                "[RTC] Inference #%s, latency=%.2fs, queue=%s",
                inference_count,
                new_latency,
                action_queue.qsize(),
            )

    logger.info("[RTC] Inference thread shutting down")
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rac_rtc_rollout_loop(
|
||||
robot,
|
||||
teleop,
|
||||
shared_state: SharedState,
|
||||
policy_active: Event,
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: RaCRTCConfig,
|
||||
action_keys: list[str],
|
||||
fps: int,
|
||||
control_time_s: float,
|
||||
single_task: str,
|
||||
display_data: bool = True,
|
||||
use_rtc: bool = True,
|
||||
rtc_config: RTCConfig | None = None,
|
||||
interpolation: bool = False,
|
||||
device: str = "cuda",
|
||||
) -> dict:
|
||||
"""RaC rollout loop with RTC for smooth policy execution."""
|
||||
logger.info("[ROLLOUT] Starting rollout loop...")
|
||||
|
||||
fps = cfg.dataset.fps
|
||||
single_task = cfg.dataset.single_task
|
||||
control_time_s = cfg.dataset.episode_time_s
|
||||
"""
|
||||
RaC rollout loop with optional RTC for smooth policy execution.
|
||||
|
||||
Matches the original rac_data_collection_openarms.py structure exactly,
|
||||
but uses RTC action queue for smoother motion when use_rtc=True.
|
||||
"""
|
||||
# Reset policy and processors - EXACTLY like original
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
device = get_safe_torch_device(device)
|
||||
frame_buffer = []
|
||||
|
||||
stats = {
|
||||
"total_frames": 0,
|
||||
"autonomous_frames": 0,
|
||||
@@ -429,33 +331,28 @@ def rac_rtc_rollout_loop(
|
||||
"correction_frames": 0,
|
||||
}
|
||||
|
||||
# Start with teleop torque disabled - EXACTLY like original
|
||||
teleop.disable_torque()
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
|
||||
# Interpolation state
|
||||
|
||||
# RTC state (only used when use_rtc=True)
|
||||
action_queue = None
|
||||
latency_tracker = None
|
||||
time_per_chunk = 1.0 / fps
|
||||
prev_action: Tensor | None = None
|
||||
interpolated_actions: list[Tensor] = []
|
||||
interp_idx = 0
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
if cfg.interpolation:
|
||||
interp_factor = 2
|
||||
control_interval = 1.0 / (fps * interp_factor)
|
||||
logger.info(f"[ROLLOUT] Interpolation ON: {fps}Hz -> {fps * interp_factor}Hz")
|
||||
else:
|
||||
interp_factor = 1
|
||||
control_interval = 1.0 / fps
|
||||
logger.info(f"[ROLLOUT] Interpolation OFF: {fps}Hz")
|
||||
|
||||
# Hz tracking
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_time = time.perf_counter()
|
||||
last_record_time = 0.0
|
||||
|
||||
if use_rtc and rtc_config:
|
||||
action_queue = ActionQueue(rtc_config)
|
||||
latency_tracker = LatencyTracker()
|
||||
get_actions_threshold = 30 if rtc_config.enabled else 0
|
||||
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
first_iteration = True
|
||||
robot_action = {} # Initialize for log_rerun_data
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
@@ -466,21 +363,10 @@ def rac_rtc_rollout_loop(
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Get observation (always from main thread - only place robot is read)
|
||||
if first_iteration:
|
||||
logger.info("[ROLLOUT] First iteration - reading observation from robot...")
|
||||
obs = robot.get_observation()
|
||||
if first_iteration:
|
||||
logger.info("[ROLLOUT] First observation read OK")
|
||||
first_iteration = False
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
|
||||
# Update shared observation for RTC thread
|
||||
shared_state.obs = obs_filtered
|
||||
|
||||
# State transition: entering paused state
|
||||
# Detect transition to paused state - EXACTLY like original
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear() # Stop RTC inference
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
print("[RaC] Moving teleop to robot position (2s smooth transition)...")
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
@@ -493,7 +379,7 @@ def rac_rtc_rollout_loop(
|
||||
interpolated_actions = []
|
||||
interp_idx = 0
|
||||
|
||||
# Wait for start button before enabling correction mode
|
||||
# Wait for start button - EXACTLY like original
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
print("[RaC] Start pressed - enabling teleop control...")
|
||||
teleop.disable_torque()
|
||||
@@ -501,98 +387,130 @@ def rac_rtc_rollout_loop(
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
# Get observation - EXACTLY like original
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
# Human controlling - record correction data
|
||||
# Human controlling - EXACTLY like original
|
||||
robot_action = teleop.get_action()
|
||||
for key in robot_action:
|
||||
if "gripper" in key:
|
||||
robot_action[key] = -0.65 * robot_action[key]
|
||||
robot.send_action(robot_action)
|
||||
robot_send_count += 1
|
||||
stats["correction_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
elif waiting_for_takeover:
|
||||
# Waiting for START - policy stopped, no recording
|
||||
# Waiting for START - EXACTLY like original (no action sent to robot!)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
elif events["policy_paused"]:
|
||||
# Paused - teleop tracks robot position
|
||||
# Paused - teleop tracks robot - EXACTLY like original
|
||||
robot_pos = {k: v for k, v in obs_filtered.items() if k.endswith(".pos")}
|
||||
teleop.send_feedback(robot_pos)
|
||||
stats["paused_frames"] += 1
|
||||
|
||||
else:
|
||||
# Policy execution with RTC
|
||||
policy_active.set()
|
||||
action_queue = shared_state.action_queue
|
||||
|
||||
# Get next action from queue (with interpolation)
|
||||
if interp_idx >= len(interpolated_actions):
|
||||
new_action = action_queue.get() if action_queue else None
|
||||
if new_action is not None:
|
||||
current_action = new_action.cpu()
|
||||
policy_consume_count += 1
|
||||
# Policy execution - use RTC if enabled, otherwise original predict_action
|
||||
if use_rtc and action_queue is not None:
|
||||
# RTC path: check if we need to generate more actions
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
if cfg.interpolation and prev_action is not None:
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
# Run inference - using predict_action for consistency with original
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
# Get action from queue
|
||||
queue_action = action_queue.get()
|
||||
if queue_action is not None:
|
||||
current_action = queue_action.cpu() if isinstance(queue_action, Tensor) else queue_action
|
||||
|
||||
# Handle interpolation
|
||||
if interpolation and prev_action is not None and isinstance(current_action, Tensor):
|
||||
mid = prev_action + 0.5 * (current_action - prev_action)
|
||||
interpolated_actions = [mid, current_action]
|
||||
else:
|
||||
interpolated_actions = [current_action]
|
||||
|
||||
prev_action = current_action
|
||||
if isinstance(current_action, Tensor):
|
||||
prev_action = current_action
|
||||
interp_idx = 0
|
||||
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
action_dict[key] = action_to_send[i].item()
|
||||
|
||||
robot.send_action(action_dict)
|
||||
robot_send_count += 1
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record at dataset fps (not interpolated rate)
|
||||
now = time.perf_counter()
|
||||
if now - last_record_time >= (1.0 / fps):
|
||||
last_record_time = now
|
||||
action_frame = build_dataset_frame(dataset.features, action_dict, prefix=ACTION)
|
||||
# Send interpolated action
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
if isinstance(action_to_send, Tensor):
|
||||
robot_action = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
robot_action[key] = action_to_send[i].item()
|
||||
else:
|
||||
robot_action = action_to_send
|
||||
|
||||
robot.send_action(robot_action)
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
else:
|
||||
# Original path - EXACTLY like original rac_data_collection_openarms.py
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot.send_action(robot_action)
|
||||
stats["autonomous_frames"] += 1
|
||||
|
||||
# Record this frame
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": single_task}
|
||||
frame_buffer.append(frame)
|
||||
stats["total_frames"] += 1
|
||||
|
||||
# Print Hz stats every 5 seconds
|
||||
now = time.perf_counter()
|
||||
if now - last_hz_time >= 5.0:
|
||||
elapsed = now - last_hz_time
|
||||
actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0
|
||||
actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0
|
||||
mode = "CORRECTION" if events["correction_active"] else ("PAUSED" if events["policy_paused"] else "POLICY")
|
||||
logger.info(f"[Hz] Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}, Mode: {mode}")
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_time = now
|
||||
|
||||
if cfg.display_data:
|
||||
log_rerun_data(observation=obs_filtered, action=action_dict if 'action_dict' in dir() else {})
|
||||
if display_data:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_time = control_interval - dt
|
||||
if sleep_time > 0:
|
||||
precise_sleep(sleep_time)
|
||||
precise_sleep(1 / fps - dt)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
policy_active.clear()
|
||||
# Ensure teleoperator torque is disabled at end - EXACTLY like original
|
||||
teleop.disable_torque()
|
||||
|
||||
for frame in frame_buffer:
|
||||
@@ -601,8 +519,12 @@ def rac_rtc_rollout_loop(
|
||||
return stats
|
||||
|
||||
|
||||
def reset_loop(robot, teleop, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
def reset_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
events: dict,
|
||||
fps: int,
|
||||
):
|
||||
print("\n" + "=" * 65)
|
||||
print(" [RaC] RESET - Moving teleop to robot position...")
|
||||
print("=" * 65)
|
||||
@@ -672,10 +594,6 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
shared_state = SharedState()
|
||||
rtc_thread = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
@@ -703,73 +621,47 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
# Load policy - same as original
|
||||
policy = None
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
|
||||
if cfg.policy:
|
||||
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
# Setup RTC if enabled
|
||||
if cfg.rtc.enabled:
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
# Setup preprocessor/postprocessor for RTC thread
|
||||
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
logger.info("Robot connected, waiting for CAN bus to stabilize...")
|
||||
time.sleep(1.0) # Let CAN bus stabilize
|
||||
|
||||
# Test read to verify robot communication is working
|
||||
logger.info("Testing robot communication...")
|
||||
test_obs = robot.get_observation()
|
||||
logger.info(f"Robot test read OK, got {len(test_obs)} observation keys")
|
||||
time.sleep(0.5)
|
||||
|
||||
teleop.connect()
|
||||
listener, events = init_rac_keyboard_listener()
|
||||
|
||||
# Get action keys for the robot
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
logger.info(f"Action keys: {action_keys}")
|
||||
|
||||
# Start RTC inference thread (it will be idle until policy_active is set)
|
||||
logger.info("Starting RTC inference thread (will be idle until episode starts)...")
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(
|
||||
policy,
|
||||
shared_state,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
),
|
||||
daemon=True,
|
||||
name="RTCInference",
|
||||
)
|
||||
rtc_thread.start()
|
||||
logger.info("Started RTC inference thread")
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" RaC (Recovery and Correction) Data Collection with RTC")
|
||||
print("=" * 65)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Policy: {cfg.policy.pretrained_path if cfg.policy else 'None'}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" RTC Enabled: {cfg.rtc.enabled}")
|
||||
print(f" Interpolation: {cfg.interpolation}")
|
||||
print(f" Policy Hz: {cfg.dataset.fps}, Robot Hz: {cfg.dataset.fps * 2 if cfg.interpolation else cfg.dataset.fps}")
|
||||
print(f" FPS: {cfg.dataset.fps}")
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy (teleop tracks robot, no recording)")
|
||||
@@ -784,10 +676,6 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
# Create fresh action queue for this episode
|
||||
shared_state.action_queue = ActionQueue(cfg.rtc)
|
||||
shared_state.obs = None
|
||||
|
||||
logger.info(f"\n{'='*40}")
|
||||
logger.info(f"Episode {recorded + 1} / {cfg.dataset.num_episodes}")
|
||||
logger.info(f"{'='*40}")
|
||||
@@ -795,12 +683,19 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
stats = rac_rtc_rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
shared_state=shared_state,
|
||||
policy_active=policy_active,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
action_keys=action_keys,
|
||||
fps=cfg.dataset.fps,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
use_rtc=cfg.rtc.enabled,
|
||||
rtc_config=cfg.rtc,
|
||||
interpolation=cfg.interpolation,
|
||||
device=cfg.device,
|
||||
)
|
||||
|
||||
logging.info(f"Episode stats: {stats}")
|
||||
@@ -825,13 +720,6 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if rtc_thread and rtc_thread.is_alive():
|
||||
logger.info("Waiting for RTC thread to finish...")
|
||||
rtc_thread.join(timeout=2.0)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
Reference in New Issue
Block a user