mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
1185 lines
44 KiB
Python
1185 lines
44 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Human-in-the-Loop (HIL) Data Collection with optional Real-Time Chunking (RTC).
|
|
|
|
Implements the RaC paradigm (https://arxiv.org/abs/2509.07953) for LeRobot. By default uses synchronous
|
|
inference (best for fast models like ACT / Diffusion Policy). Set --rtc.enabled=true for
|
|
asynchronous background inference (recommended for large models like Pi0 / Pi0.5 / SmolVLA).
|
|
|
|
The workflow:
|
|
1. Policy runs autonomously
|
|
2. Press SPACE to pause - robot holds position
|
|
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
|
4. Press 'p' to hand control back to policy and continue recording
|
|
5. Press → to end episode (save and continue to next)
|
|
6. Reset, then do next rollout
|
|
|
|
Keyboard Controls:
|
|
SPACE - Pause policy (robot holds position, no recording)
|
|
c - Take control (start correction, recording resumes)
|
|
p - Resume policy after pause/correction (recording continues)
|
|
→ - End episode (save and continue to next)
|
|
← - Re-record episode
|
|
ESC - Stop recording and push dataset to hub
|
|
|
|
Usage:
|
|
# Standard synchronous inference (ACT, Diffusion Policy)
|
|
python examples/hil/hil_data_collection.py \
|
|
--robot.type=bi_openarm_follower \
|
|
--teleop.type=openarm_mini \
|
|
--policy.path=path/to/pretrained_model \
|
|
--dataset.repo_id=user/hil-dataset \
|
|
--dataset.single_task="Fold the T-shirt properly" \
|
|
--dataset.fps=30 \
|
|
--interpolation_multiplier=2
|
|
|
|
# With RTC for large models (Pi0, Pi0.5, SmolVLA)
|
|
python examples/hil/hil_data_collection.py \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--rtc.max_guidance_weight=5.0 \
|
|
--rtc.prefix_attention_schedule=LINEAR \
|
|
--robot.type=bi_openarm_follower \
|
|
--teleop.type=openarm_mini \
|
|
--policy.path=path/to/pretrained_model \
|
|
--dataset.repo_id=user/hil-dataset \
|
|
--dataset.single_task="Fold the T-shirt properly" \
|
|
--dataset.fps=30 \
|
|
--interpolation_multiplier=3
|
|
|
|
# RTC with bi_openarm_follower + OpenArm Mini teleop and pi0.5 policy
|
|
python examples/hil/hil_data_collection.py \
|
|
--policy.path=lerobot-data-collection/folding_final \
|
|
--robot.type=bi_openarm_follower \
|
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
|
--robot.left_arm_config.port=can0 \
|
|
--robot.left_arm_config.side=left \
|
|
--robot.left_arm_config.can_interface=socketcan \
|
|
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
|
--robot.left_arm_config.max_relative_target=8.0 \
|
|
--robot.right_arm_config.port=can1 \
|
|
--robot.right_arm_config.side=right \
|
|
--robot.right_arm_config.can_interface=socketcan \
|
|
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
|
--robot.right_arm_config.max_relative_target=8.0 \
|
|
--teleop.type=openarm_mini \
|
|
--teleop.port_left=/dev/ttyACM1 \
|
|
--teleop.port_right=/dev/ttyACM0 \
|
|
--dataset.repo_id=lerobot-data-collection/hil_folding \
|
|
--dataset.single_task="Fold the T-shirt properly" \
|
|
--dataset.fps=30 \
|
|
--dataset.num_episodes=50 \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--rtc.max_guidance_weight=5.0 \
|
|
--rtc.prefix_attention_schedule=LINEAR \
|
|
--interpolation_multiplier=3 \
|
|
--calibrate=true \
|
|
--device=cuda
|
|
"""
|
|
|
|
import logging
|
|
import math
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pprint import pformat
|
|
from threading import Event, Lock, Thread
|
|
from typing import Any
|
|
|
|
import torch
|
|
from hil_utils import (
|
|
HILDatasetConfig,
|
|
init_keyboard_listener,
|
|
make_identity_processors,
|
|
print_controls,
|
|
reset_loop,
|
|
teleop_disable_torque,
|
|
teleop_smooth_move_to,
|
|
)
|
|
|
|
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
|
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
|
from lerobot.common.control_utils import is_headless, predict_action
|
|
from lerobot.configs import PreTrainedConfig, parser
|
|
from lerobot.datasets import (
|
|
LeRobotDataset,
|
|
VideoEncodingManager,
|
|
aggregate_pipeline_dataset_features,
|
|
create_initial_features,
|
|
safe_stop_image_writer,
|
|
)
|
|
from lerobot.policies import PreTrainedPolicy, get_policy_class, make_policy, make_pre_post_processors
|
|
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
|
from lerobot.policies.utils import make_robot_action
|
|
from lerobot.processor import (
|
|
NormalizerProcessorStep,
|
|
PolicyProcessorPipeline,
|
|
RelativeActionsProcessorStep,
|
|
TransitionKey,
|
|
create_transition,
|
|
rename_stats,
|
|
to_relative_actions,
|
|
)
|
|
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
|
from lerobot.robots.bi_openarm_follower import BiOpenArmFollowerConfig
|
|
from lerobot.robots.so_follower import SOFollowerRobotConfig # noqa: F401
|
|
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
|
from lerobot.teleoperators.openarm_mini import OpenArmMiniConfig # noqa: F401
|
|
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig # noqa: F401
|
|
from lerobot.utils import get_safe_torch_device
|
|
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
|
from lerobot.utils.robot_utils import precise_sleep
|
|
from lerobot.utils.utils import init_logging, log_say
|
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
|
|
|
# Module logger for this script (named after the module via __name__).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# RTC helpers
|
|
|
|
|
|
class ThreadSafeRobot:
    """Lock-guarded facade over a ``Robot`` for the RTC rollout mode.

    ``get_observation`` and ``send_action`` are executed while holding a single
    ``threading.Lock`` so that calls made from different threads never run
    concurrently on the wrapped robot. Metadata properties are forwarded
    directly, without taking the lock.
    """

    def __init__(self, robot: Robot):
        self._lock = Lock()
        self._robot = robot

    def get_observation(self) -> dict[str, Any]:
        """Return ``robot.get_observation()`` computed under the lock."""
        with self._lock:
            observation = self._robot.get_observation()
        return observation

    def send_action(self, action: dict) -> None:
        """Call ``robot.send_action(action)`` under the lock; the result is discarded."""
        with self._lock:
            self._robot.send_action(action)

    @property
    def observation_features(self) -> dict:
        """Observation feature spec of the wrapped robot."""
        wrapped = self._robot
        return wrapped.observation_features

    @property
    def action_features(self) -> dict:
        """Action feature spec of the wrapped robot."""
        wrapped = self._robot
        return wrapped.action_features

    @property
    def name(self) -> str:
        """Name of the wrapped robot."""
        wrapped = self._robot
        return wrapped.name

    @property
    def robot_type(self) -> str:
        """``robot_type`` of the wrapped robot."""
        wrapped = self._robot
        return wrapped.robot_type

    @property
    def cameras(self):
        """Wrapped robot's ``cameras`` attribute, or an empty dict when it has none."""
        try:
            return self._robot.cameras
        except AttributeError:
            return {}
|
|
|
|
|
|
def _set_openarm_max_relative_target_if_missing(
    robot_cfg: RobotConfig, max_relative_target: float = 8.0
) -> None:
    """Fill in ``max_relative_target`` on both arms of a bimanual OpenArm follower.

    Only ``BiOpenArmFollowerConfig`` instances are touched; any other robot config
    is left unchanged. An arm whose ``max_relative_target`` is already set (not
    ``None``) keeps its value. The config is mutated in place.

    Args:
        robot_cfg: Robot configuration to patch.
        max_relative_target: Value written into arms that have no limit configured.
    """
    if not isinstance(robot_cfg, BiOpenArmFollowerConfig):
        return
    for arm_cfg in (robot_cfg.left_arm_config, robot_cfg.right_arm_config):
        if arm_cfg.max_relative_target is None:
            arm_cfg.max_relative_target = max_relative_target
|
|
|
|
|
|
def _reanchor_relative_rtc_prefix(
    prev_actions_absolute: torch.Tensor,
    current_state: torch.Tensor,
    relative_step: RelativeActionsProcessorStep | None,
    normalizer_step: NormalizerProcessorStep | None,
    policy_device: torch.device | str,
) -> torch.Tensor:
    """Re-express leftover absolute actions in the policy's model space.

    When no relative-action step is configured, the absolute actions are just moved
    to ``policy_device``. Otherwise the actions and state are copied to CPU, and
    the actions are converted with ``to_relative_actions`` against the current state
    using the step's mask. The result is optionally passed through the normalizer
    step, and the resulting action tensor is moved to ``policy_device``.

    Args:
        prev_actions_absolute: Leftover actions from the previous chunk, in absolute space.
        current_state: Current robot state; a 1-D tensor gets a leading batch dim.
        relative_step: Enabled relative-actions processor step, or ``None``.
        normalizer_step: Normalizer step applied after the conversion, or ``None``.
        policy_device: Device the returned tensor is placed on.

    Returns:
        The converted (and possibly normalized) prefix on ``policy_device``.
    """
    if relative_step is None:
        return prev_actions_absolute.to(policy_device)

    actions_host = prev_actions_absolute.detach().cpu()
    state_host = current_state.detach().cpu()
    state_host = state_host.unsqueeze(0) if state_host.dim() == 1 else state_host

    relative_mask = relative_step._build_mask(actions_host.shape[-1])
    transition = create_transition(
        action=to_relative_actions(actions_host, state_host, relative_mask)
    )
    if normalizer_step is not None:
        transition = normalizer_step(transition)

    return transition[TransitionKey.ACTION].to(policy_device)
|
|
|
|
|
|
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
|
"""Pad/truncate RTC prefix actions to a fixed length for stable compiled inference."""
|
|
if prev_actions.ndim != 2:
|
|
raise ValueError(f"Expected prev_actions to be 2D [T, A], got shape={tuple(prev_actions.shape)}")
|
|
|
|
steps, action_dim = prev_actions.shape
|
|
if steps == target_steps:
|
|
return prev_actions
|
|
if steps > target_steps:
|
|
return prev_actions[:target_steps]
|
|
|
|
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
|
padded[:steps] = prev_actions
|
|
return padded
|
|
|
|
|
|
def _resolve_action_key_order(cfg, dataset_action_names: list[str]) -> list[str]:
|
|
"""Choose action name ordering used to map policy tensor outputs to robot action dict."""
|
|
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
|
|
if not policy_action_names:
|
|
return dataset_action_names
|
|
|
|
policy_action_names = list(policy_action_names)
|
|
if len(policy_action_names) != len(dataset_action_names):
|
|
logger.warning(
|
|
"[RTC] policy.action_feature_names length (%d) != dataset action dim (%d); "
|
|
"falling back to dataset order",
|
|
len(policy_action_names),
|
|
len(dataset_action_names),
|
|
)
|
|
return dataset_action_names
|
|
|
|
if set(dataset_action_names) != set(policy_action_names):
|
|
logger.warning(
|
|
"[RTC] policy.action_feature_names keys do not match dataset action keys; "
|
|
"falling back to dataset order"
|
|
)
|
|
return dataset_action_names
|
|
|
|
return policy_action_names
|
|
|
|
|
|
def _resolve_state_joint_order(
|
|
policy_action_names: list[str] | None,
|
|
available_joint_names: list[str],
|
|
) -> list[str]:
|
|
"""Resolve joint-state ordering used to build observation.state."""
|
|
if not policy_action_names:
|
|
return available_joint_names
|
|
|
|
policy_action_names = list(policy_action_names)
|
|
available_set = set(available_joint_names)
|
|
policy_set = set(policy_action_names)
|
|
|
|
if len(policy_action_names) != len(available_joint_names) or policy_set != available_set:
|
|
logger.warning(
|
|
"policy.action_feature_names does not match available state joints; "
|
|
"falling back to robot observation order"
|
|
)
|
|
return available_joint_names
|
|
|
|
logger.info("Using policy.action_feature_names order for observation.state mapping")
|
|
return policy_action_names
|
|
|
|
|
|
def _start_pedal_listener(events: dict):
    """Start a foot-pedal listener thread if ``evdev`` is importable.

    Pedal input is restricted to HIL control handoff only:
    policy -> pause -> takeover -> resume policy.
    Episode save/advance remains keyboard-only (right arrow).

    Both pedal keys (``KEY_A`` and ``KEY_C``) behave identically. On each key-down:

    * during reset (``events["in_reset"]``): set ``events["start_next_episode"]``;
    * while a correction is active: set ``events["resume_policy"]``;
    * while the policy is paused: set ``events["start_next_episode"]``, which the
      rollout loops treat as the takeover request;
    * otherwise: set ``events["policy_paused"]``.

    The reader runs in a daemon thread. If the device is missing or unreadable, or
    any other error occurs while reading, the error is logged and the thread exits.
    The device handle is closed on the way out.

    Args:
        events: Shared control-flag dict, mutated in place by the reader thread.
    """
    try:
        from evdev import InputDevice, categorize, ecodes
    except ImportError:
        logger.warning("[Pedal] evdev not installed - pedal support disabled")
        return

    pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
    key_left = "KEY_A"
    key_right = "KEY_C"
    pedal_keys = (key_left, key_right)

    def pedal_reader():
        dev = None
        try:
            dev = InputDevice(pedal_device)
            logger.info("[Pedal] Connected: %s", dev.name)

            for ev in dev.read_loop():
                if ev.type != ecodes.EV_KEY:
                    continue

                key = categorize(ev)
                code = key.keycode
                # keycode may be a list/tuple of aliases; use the first one.
                if isinstance(code, (list, tuple)):
                    code = code[0]

                # Only act on keystate == 1 (evdev key-down); other states are skipped.
                if key.keystate != 1:
                    continue

                if events["in_reset"]:
                    if code in pedal_keys:
                        events["start_next_episode"] = True
                    continue

                if code not in pedal_keys:
                    continue

                if events["correction_active"]:
                    events["resume_policy"] = True
                elif events["policy_paused"]:
                    events["start_next_episode"] = True
                else:
                    events["policy_paused"] = True

        except FileNotFoundError:
            logger.info("[Pedal] Device not found: %s", pedal_device)
        except PermissionError:
            logger.warning("[Pedal] Permission denied for %s", pedal_device)
        except Exception as e:
            # Keep the listener best-effort, but preserve the traceback for debugging.
            logger.exception("[Pedal] Error: %s", e)
        finally:
            if dev is not None:
                dev.close()

    Thread(target=pedal_reader, name="hil-pedal-listener", daemon=True).start()
|
|
|
|
|
|
def _rtc_inference_thread(
    policy: PreTrainedPolicy,
    obs_holder: dict,
    obs_lock: Lock,
    hw_features: dict,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    queue_holder: dict,
    shutdown_event: Event,
    policy_active: Event,
    compile_warmup_done: Event,
    cfg,
):
    """Background thread for RTC action chunk generation.

    Loops until ``shutdown_event`` is set. While ``policy_active`` is set and both
    ``queue_holder["queue"]`` and ``obs_holder["obs"]`` are available, it runs one
    inference whenever ``queue.qsize() <= 30``:

    * builds a batch from the latest observation;
    * passes the queue's leftover actions as the RTC prefix (re-anchored to the
      current state when a relative-actions step is enabled, then padded/truncated
      to ``cfg.rtc.execution_horizon``);
    * uses the max tracked latency, converted to steps at ``cfg.dataset.fps``, as
      the inference delay;
    * merges the new chunk into the queue.

    With ``cfg.use_torch_compile`` the first ``max(1, compile_warmup_inferences)``
    inferences are not added to the latency tracker. ``compile_warmup_done`` is set
    once they have run.

    Errors raised during an inference are logged with their traceback. The thread
    then sleeps 0.5 s and retries.
    """
    latency_tracker = LatencyTracker()
    time_per_chunk = 1.0 / cfg.dataset.fps
    # Request a new chunk once the queue holds this many actions or fewer.
    threshold = 30
    policy_device = policy.config.device
    stats_window_start = time.perf_counter()
    policy_inference_count = 0
    latency_sum_s = 0.0
    inference_count = 0
    warmup_required = max(1, int(cfg.compile_warmup_inferences)) if cfg.use_torch_compile else 0

    relative_step = next(
        (
            step
            for step in preprocessor.steps
            if isinstance(step, RelativeActionsProcessorStep) and step.enabled
        ),
        None,
    )
    normalizer_step = next(
        (step for step in preprocessor.steps if isinstance(step, NormalizerProcessorStep)),
        None,
    )
    if relative_step is not None:
        # The relative step needs action names to build its mask; prefer the policy's
        # declared names, then whatever the caller stashed in obs_holder.
        if relative_step.action_names is None:
            cfg_action_names = getattr(cfg.policy, "action_feature_names", None)
            if cfg_action_names:
                relative_step.action_names = list(cfg_action_names)
            else:
                fallback_action_names = obs_holder.get("action_feature_names")
                if fallback_action_names:
                    relative_step.action_names = list(fallback_action_names)
        logger.info("[RTC] Relative actions enabled: re-anchoring RTC prefix to current state")

    while not shutdown_event.is_set():
        if not policy_active.is_set():
            time.sleep(0.01)
            continue

        # The control loop may swap in a fresh queue / observation at any time.
        queue = queue_holder.get("queue")
        with obs_lock:
            obs = obs_holder.get("obs")
        if queue is None or obs is None:
            time.sleep(0.01)
            continue

        if queue.qsize() <= threshold:
            try:
                current_time = time.perf_counter()
                idx_before = queue.get_action_index()
                prev_actions = queue.get_left_over()

                latency = latency_tracker.max()
                delay = math.ceil(latency / time_per_chunk) if latency else 0

                obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
                for name in obs_batch:
                    obs_batch[name] = torch.from_numpy(obs_batch[name])
                    if "image" in name:
                        # Scale to [0, 1] and move channels first (HWC -> CHW).
                        obs_batch[name] = obs_batch[name].float() / 255
                        obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
                    obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device)

                obs_batch["task"] = [cfg.dataset.single_task]
                obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")

                preprocessed = preprocessor(obs_batch)

                if prev_actions is not None and relative_step is not None and OBS_STATE in obs_batch:
                    prev_actions_absolute = queue.get_processed_left_over()
                    if prev_actions_absolute is not None and prev_actions_absolute.numel() > 0:
                        prev_actions = _reanchor_relative_rtc_prefix(
                            prev_actions_absolute=prev_actions_absolute,
                            current_state=obs_batch[OBS_STATE],
                            relative_step=relative_step,
                            normalizer_step=normalizer_step,
                            policy_device=policy_device,
                        )

                if prev_actions is not None:
                    prev_actions = _normalize_prev_actions_length(
                        prev_actions, target_steps=cfg.rtc.execution_horizon
                    )

                actions = policy.predict_action_chunk(
                    preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
                )

                original = actions.squeeze(0).clone()
                processed = postprocessor(actions).squeeze(0)
                new_latency = time.perf_counter() - current_time
                new_delay = math.ceil(new_latency / time_per_chunk)
                inference_count += 1
                is_warmup_inference = cfg.use_torch_compile and inference_count <= warmup_required
                if is_warmup_inference:
                    # Compile-time latencies would inflate future delay estimates.
                    latency_tracker.reset()
                else:
                    latency_tracker.add(new_latency)
                queue.merge(original, processed, new_delay, idx_before)
                policy_inference_count += 1
                latency_sum_s += new_latency
                if (
                    is_warmup_inference
                    and inference_count >= warmup_required
                    and not compile_warmup_done.is_set()
                ):
                    compile_warmup_done.set()
                    logger.info(
                        "[RTC] Compile warmup complete (%d/%d inferences)",
                        inference_count,
                        warmup_required,
                    )
                logger.debug("[RTC] Inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
            except Exception as e:
                # Keep the worker alive, but log the full traceback so failures are debuggable.
                logger.exception("[RTC] Error: %s", e)
                time.sleep(0.5)
        else:
            time.sleep(0.01)

        now = time.perf_counter()
        if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s:
            policy_hz = policy_inference_count / window_elapsed
            avg_latency_ms = (
                (latency_sum_s / policy_inference_count * 1000.0) if policy_inference_count else 0.0
            )
            logger.info(
                "[HIL RTC rates] policy=%.1f Hz | avg_inference=%.1f ms | queue=%d",
                policy_hz,
                avg_latency_ms,
                queue.qsize(),
            )
            stats_window_start = now
            policy_inference_count = 0
            latency_sum_s = 0.0
|
|
|
|
|
|
# Config
|
|
|
|
|
|
@dataclass
class HILConfig:
    """Command-line configuration for human-in-the-loop data collection.

    ``policy`` is loaded in ``__post_init__`` from ``--policy.path``, with any
    ``--policy.*`` CLI overrides applied. A missing policy is a configuration error.
    """

    robot: RobotConfig
    teleop: TeleoperatorConfig
    dataset: HILDatasetConfig
    # Replaced by the pretrained config loaded from --policy.path in __post_init__.
    policy: PreTrainedConfig | None = None
    # RTC settings; rtc.enabled selects the asynchronous rollout loop.
    rtc: RTCConfig = field(default_factory=RTCConfig)
    # Multiplier passed to ActionInterpolator; the logged target robot rate is
    # dataset.fps * interpolation_multiplier.
    interpolation_multiplier: int = 2
    # When False, only every interpolation_multiplier-th control tick is recorded.
    record_interpolated_actions: bool = False
    display_data: bool = True
    play_sounds: bool = True
    resume: bool = False
    # Torch device name (resolved with get_safe_torch_device in the sync loop).
    device: str = "cuda"
    use_torch_compile: bool = False
    # Number of warmup inferences (at least 1) before RTC rollout starts when compiling.
    compile_warmup_inferences: int = 2
    calibrate: bool = False
    # Periodic logging of achieved policy/robot rates.
    log_hz: bool = True
    hz_log_interval_s: float = 2.0

    def __post_init__(self):
        pretrained_path = parser.get_path_arg("policy")
        if pretrained_path:
            overrides = parser.get_cli_overrides("policy")
            loaded = PreTrainedConfig.from_pretrained(pretrained_path, cli_overrides=overrides)
            loaded.pretrained_path = pretrained_path
            self.policy = loaded
        if self.policy is None:
            raise ValueError("policy.path is required")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Fields that accept a ``--<field>.path`` argument on the CLI."""
        return ["policy"]
|
|
|
|
|
|
# Rollout loops
|
|
|
|
|
|
@safe_stop_image_writer
def _rollout_sync(
    robot: Robot,
    teleop: Teleoperator,
    policy: PreTrainedPolicy,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    dataset: LeRobotDataset,
    events: dict,
    cfg: HILConfig,
):
    """Rollout loop with standard synchronous inference.

    Runs one episode until ``cfg.dataset.episode_time_s`` elapses or
    ``events["exit_early"]`` is set. The ``events`` flags drive a small state machine:

    * policy mode: the policy is queried (blocking) whenever the interpolator needs
      a new action; interpolated actions are sent to the robot and recorded;
    * ``policy_paused``: the teleop is smoothly moved to the robot's joint positions,
      the robot keeps receiving the last policy action, and nothing is recorded;
    * takeover (``start_next_episode`` while waiting): teleop torque is disabled and
      ``correction_active`` is set; teleop actions are sent and recorded;
    * ``resume_policy``: all flags and the policy/processor/interpolator state are
      reset, returning to policy mode.

    One frame is recorded every ``record_stride`` recording ticks. Frames go straight
    to ``dataset.add_frame`` when streaming encoding is enabled; otherwise they are
    buffered and added after the loop ends.
    """
    fps = cfg.dataset.fps
    device = get_safe_torch_device(cfg.device)
    stream_online = bool(cfg.dataset.streaming_encoding)
    # Record every tick, or only one tick per interpolation group.
    record_stride = 1 if cfg.record_interpolated_actions else max(1, cfg.interpolation_multiplier)

    policy.reset()
    preprocessor.reset()
    postprocessor.reset()

    frame_buffer: list[dict] = []
    teleop_disable_torque(teleop)

    was_paused = False
    waiting_for_takeover = False
    last_action: dict[str, Any] | None = None
    robot_action: dict[str, Any] = {}
    action_keys = list(dataset.features[ACTION]["names"])
    obs_state_names = list(dataset.features[f"{OBS_STR}.state"]["names"])
    # Raw camera keys, recovered from the dataset's "observation.images.<cam>" features.
    obs_image_names = [
        key.removeprefix(f"{OBS_STR}.images.")
        for key in dataset.features
        if key.startswith(f"{OBS_STR}.images.")
    ]

    interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
    control_interval = interpolator.get_control_interval(fps)

    timestamp = 0.0
    record_tick = 0
    start_t = time.perf_counter()
    stats_window_start = start_t
    policy_inference_count = 0
    robot_command_count = 0

    while timestamp < cfg.dataset.episode_time_s:
        loop_start = time.perf_counter()

        if events["exit_early"]:
            events["exit_early"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            events["resume_policy"] = False
            break

        # Resume policy: clear all HIL state and restart inference from scratch.
        if events["resume_policy"] and (
            events["policy_paused"] or events["correction_active"] or waiting_for_takeover
        ):
            events["resume_policy"] = False
            events["start_next_episode"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            waiting_for_takeover = False
            was_paused = False
            last_action = None
            interpolator.reset()
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()

        # First tick of a pause: align the teleop with the robot before takeover.
        if events["policy_paused"] and not was_paused:
            obs = robot.get_observation()
            robot_pos = {
                k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
            }
            teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
            # Drop any takeover request that arrived before alignment finished.
            events["start_next_episode"] = False
            waiting_for_takeover = True
            was_paused = True
            interpolator.reset()

        # Takeover: free the teleop so the human can move it, and start correction mode.
        if waiting_for_takeover and events["start_next_episode"]:
            teleop_disable_torque(teleop)
            events["start_next_episode"] = False
            events["correction_active"] = True
            waiting_for_takeover = False

        obs = robot.get_observation()
        obs_filtered = {k: obs[k] for k in obs_state_names if k in obs}
        obs_filtered.update({k: obs[k] for k in obs_image_names if k in obs})
        obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)

        if events["correction_active"]:
            # Human correction: mirror the teleop onto the robot and record it.
            robot_action = teleop.get_action()
            robot.send_action(robot_action)
            robot_command_count += 1
            action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
            if record_tick % record_stride == 0:
                frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                if stream_online:
                    dataset.add_frame(frame)
                else:
                    frame_buffer.append(frame)
            record_tick += 1

        elif waiting_for_takeover or events["policy_paused"]:
            # Paused: hold position by resending the last policy action; nothing is recorded.
            if last_action:
                robot.send_action(last_action)
                robot_command_count += 1

        else:
            # Policy mode: query the policy only when the interpolator runs out of targets.
            if interpolator.needs_new_action():
                action_values = predict_action(
                    observation=obs_frame,
                    policy=policy,
                    device=device,
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                    use_amp=policy.config.use_amp,
                    task=cfg.dataset.single_task,
                    robot_type=robot.robot_type,
                )
                policy_inference_count += 1
                robot_action = make_robot_action(action_values, dataset.features)
                action_tensor = torch.tensor([robot_action[k] for k in action_keys])
                interpolator.add(action_tensor)

            interp_action = interpolator.get()
            if interp_action is not None:
                robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
                robot.send_action(robot_action)
                robot_command_count += 1
                last_action = robot_action
                action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
                if record_tick % record_stride == 0:
                    frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                    if stream_online:
                        dataset.add_frame(frame)
                    else:
                        frame_buffer.append(frame)
                record_tick += 1

        # Note: while paused this re-logs the most recent robot_action, not a new command.
        if cfg.display_data and robot_action:
            log_rerun_data(observation=obs_filtered, action=robot_action)

        dt = time.perf_counter() - loop_start
        if (sleep_time := control_interval - dt) > 0:
            precise_sleep(sleep_time)
        now = time.perf_counter()
        timestamp = now - start_t

        if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s:
            policy_hz = policy_inference_count / window_elapsed
            robot_hz = robot_command_count / window_elapsed
            logger.info(
                "[HIL rates] policy=%.1f Hz (target=%.1f) | robot=%.1f Hz (target=%.1f)",
                policy_hz,
                fps,
                robot_hz,
                fps * cfg.interpolation_multiplier,
            )
            stats_window_start = now
            policy_inference_count = 0
            robot_command_count = 0

    teleop_disable_torque(teleop)

    # Non-streaming mode: flush buffered frames only after the timed loop has finished.
    if not stream_online:
        for frame in frame_buffer:
            dataset.add_frame(frame)
|
|
|
|
|
|
@safe_stop_image_writer
def _rollout_rtc(
    robot,
    teleop: Teleoperator,
    policy: PreTrainedPolicy,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    dataset: LeRobotDataset,
    events: dict,
    cfg: HILConfig,
    queue_holder: dict,
    obs_holder: dict,
    obs_lock: Lock,
    policy_active: Event,
    compile_warmup_done: Event,
    hw_features: dict,
):
    """Rollout loop with RTC for asynchronous inference.

    Same pause / takeover / resume state machine as the synchronous loop. The
    difference is that actions come from ``queue_holder["queue"]``, which a
    background inference thread fills. This loop publishes the latest filtered
    observation in ``obs_holder["obs"]`` (under ``obs_lock``) and toggles
    ``policy_active`` to start and stop that thread.

    Observations are polled at most once per ``1 / fps`` while the policy is running,
    and on every tick otherwise. With ``cfg.use_torch_compile`` no policy actions are
    executed until ``compile_warmup_done`` is set. After that, the queue filled during
    warmup is discarded once.

    ``hw_features`` is not used in this body.

    NOTE(review): unlike the synchronous loop, this loop never calls
    ``log_rerun_data`` even when ``cfg.display_data`` is True. Confirm whether this
    is intentional, e.g. to keep the control loop fast.
    """
    fps = cfg.dataset.fps
    stream_online = bool(cfg.dataset.streaming_encoding)
    # Record every tick, or only one tick per interpolation group.
    record_stride = 1 if cfg.record_interpolated_actions else max(1, cfg.interpolation_multiplier)

    policy.reset()
    preprocessor.reset()
    postprocessor.reset()

    frame_buffer: list[dict] = []
    teleop_disable_torque(teleop)

    was_paused = False
    waiting_for_takeover = False
    last_action: dict[str, Any] | None = None
    dataset_action_keys = list(dataset.features[ACTION]["names"])
    # Order used to map queued action tensors back onto named robot joints.
    action_keys = _resolve_action_key_order(cfg, dataset_action_keys)
    if action_keys != dataset_action_keys:
        logger.info("[RTC] Using policy.action_feature_names order for action tensor mapping")
    else:
        logger.info("[RTC] Using dataset action feature order for action tensor mapping")
    obs_state_names = list(dataset.features[f"{OBS_STR}.state"]["names"])
    obs_image_names = [
        key.removeprefix(f"{OBS_STR}.images.")
        for key in dataset.features
        if key.startswith(f"{OBS_STR}.images.")
    ]

    interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
    control_interval = interpolator.get_control_interval(fps)

    robot_action: dict[str, Any] = {}
    timestamp = 0.0
    start_t = time.perf_counter()
    stats_window_start = start_t
    robot_command_count = 0
    record_tick = 0
    # During policy rollout, observations are refreshed at dataset fps rather than every control tick.
    obs_poll_interval = 1.0 / fps
    last_obs_poll_t = 0.0
    obs_filtered: dict[str, Any] = {}
    obs_frame: dict[str, Any] = {}
    warmup_wait_logged = False
    warmup_queue_flushed = False

    while timestamp < cfg.dataset.episode_time_s:
        loop_start = time.perf_counter()

        if events["exit_early"]:
            events["exit_early"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            events["resume_policy"] = False
            break

        # Resume policy: reset HIL state, drop queued actions, and pause the inference
        # thread until the policy branch below reactivates it.
        if events["resume_policy"] and (
            events["policy_paused"] or events["correction_active"] or waiting_for_takeover
        ):
            events["resume_policy"] = False
            events["start_next_episode"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            waiting_for_takeover = False
            was_paused = False
            last_action = None
            interpolator.reset()
            queue_holder["queue"] = ActionQueue(cfg.rtc)
            policy_active.clear()
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()

        # First tick of a pause: stop inference and align the teleop with the robot.
        if events["policy_paused"] and not was_paused:
            policy_active.clear()
            obs = robot.get_observation()
            robot_pos = {
                k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
            }
            teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
            # Drop any takeover request that arrived before alignment finished.
            events["start_next_episode"] = False
            waiting_for_takeover = True
            was_paused = True
            interpolator.reset()

        # Takeover: free the teleop, enter correction mode, and discard stale policy actions.
        if waiting_for_takeover and events["start_next_episode"]:
            teleop_disable_torque(teleop)
            events["start_next_episode"] = False
            events["correction_active"] = True
            waiting_for_takeover = False
            queue_holder["queue"] = ActionQueue(cfg.rtc)

        now_for_obs = time.perf_counter()
        # Poll every tick outside policy mode; during policy mode only every obs_poll_interval.
        should_poll_obs = (
            not obs_filtered
            or (now_for_obs - last_obs_poll_t) >= obs_poll_interval
            or events["correction_active"]
            or waiting_for_takeover
            or events["policy_paused"]
        )
        if should_poll_obs:
            obs = robot.get_observation()
            obs_filtered = {k: obs[k] for k in obs_state_names if k in obs}
            obs_filtered.update({k: obs[k] for k in obs_image_names if k in obs})
            obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
            # Publish the new observation for the background inference thread.
            with obs_lock:
                obs_holder["obs"] = obs_filtered
            last_obs_poll_t = now_for_obs

        if events["correction_active"]:
            # Human correction: mirror the teleop onto the robot and record it.
            robot_action = teleop.get_action()
            robot.send_action(robot_action)
            robot_command_count += 1
            action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
            if record_tick % record_stride == 0:
                frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                if stream_online:
                    dataset.add_frame(frame)
                else:
                    frame_buffer.append(frame)
            record_tick += 1

        elif waiting_for_takeover or events["policy_paused"]:
            # Paused: hold position by resending the last policy action; nothing is recorded.
            if last_action:
                robot.send_action(last_action)
                robot_command_count += 1

        else:
            # Policy mode: make sure the background inference thread is running.
            if not policy_active.is_set():
                policy_active.set()

            if cfg.use_torch_compile and not compile_warmup_done.is_set():
                # Robot is not commanded until compile warmup finishes.
                if not warmup_wait_logged:
                    logger.info(
                        "[RTC] Waiting for compile warmup (%d inferences) before policy rollout",
                        max(1, int(cfg.compile_warmup_inferences)),
                    )
                    warmup_wait_logged = True
            else:
                # Discard actions produced during warmup, once per call.
                if cfg.use_torch_compile and not warmup_queue_flushed:
                    queue_holder["queue"] = ActionQueue(cfg.rtc)
                    interpolator.reset()
                    warmup_queue_flushed = True
                    logger.info("[RTC] Warmup queue cleared; starting live policy rollout")

                queue = queue_holder["queue"]

                if interpolator.needs_new_action():
                    # NOTE(review): `if queue` uses the queue's truthiness; confirm ActionQueue
                    # does not define __len__/__bool__ in a way that makes a valid queue falsy.
                    new_action = queue.get() if queue else None
                    if new_action is not None:
                        interpolator.add(new_action.cpu())

                action_tensor = interpolator.get()
                if action_tensor is not None:
                    # Map tensor entries to joint names; extra keys beyond the tensor length are skipped.
                    robot_action = {
                        k: action_tensor[i].item()
                        for i, k in enumerate(action_keys)
                        if i < len(action_tensor)
                    }
                    robot.send_action(robot_action)
                    robot_command_count += 1
                    last_action = robot_action
                    action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
                    if record_tick % record_stride == 0:
                        frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                        if stream_online:
                            dataset.add_frame(frame)
                        else:
                            frame_buffer.append(frame)
                    record_tick += 1

        dt = time.perf_counter() - loop_start
        if (sleep_time := control_interval - dt) > 0:
            precise_sleep(sleep_time)
        now = time.perf_counter()
        timestamp = now - start_t

        if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s:
            robot_hz = robot_command_count / window_elapsed
            logger.info(
                "[HIL RTC rates] robot=%.1f Hz (target=%.1f)",
                robot_hz,
                fps * cfg.interpolation_multiplier,
            )
            stats_window_start = now
            robot_command_count = 0

    # Stop background inference before handing control back to the caller.
    policy_active.clear()
    teleop_disable_torque(teleop)

    # Non-streaming mode: flush buffered frames only after the timed loop has finished.
    if not stream_online:
        for frame in frame_buffer:
            dataset.add_frame(frame)
|
|
|
|
|
|
# Main collection function
|
|
|
|
|
|
def _build_hil_dataset_features(robot_raw, cfg: HILConfig, teleop_proc, obs_proc) -> tuple[dict, dict]:
    """Return ``(dataset_features, observation_features_hw)`` for the robot being recorded.

    Only ``*.pos`` action features are recorded. Observation joints are the robot's
    ``*.pos`` features whose declared type is ``float``. They are ordered by
    ``_resolve_state_joint_order`` against the policy's ``action_feature_names`` (when the
    policy config has one). Every tuple-valued observation feature (presumably camera
    image shapes; confirm) is appended after the joints.
    """
    action_features_hw = {k: v for k, v in robot_raw.action_features.items() if k.endswith(".pos")}

    all_observation_features = robot_raw.observation_features
    available_joint_names = [
        key for key, value in all_observation_features.items() if key.endswith(".pos") and value is float
    ]
    ordered_joint_names = _resolve_state_joint_order(
        getattr(cfg.policy, "action_feature_names", None),
        available_joint_names,
    )
    observation_features_hw = {name: all_observation_features[name] for name in ordered_joint_names}
    for key, value in all_observation_features.items():
        if isinstance(value, tuple):
            observation_features_hw[key] = value

    dataset_features = combine_feature_dicts(
        aggregate_pipeline_dataset_features(
            pipeline=teleop_proc,
            initial_features=create_initial_features(action=action_features_hw),
            use_videos=cfg.dataset.video,
        ),
        aggregate_pipeline_dataset_features(
            pipeline=obs_proc,
            initial_features=create_initial_features(observation=observation_features_hw),
            use_videos=cfg.dataset.video,
        ),
    )
    return dataset_features, observation_features_hw


def _open_hil_dataset(cfg: HILConfig, robot_raw, dataset_features: dict) -> LeRobotDataset:
    """Reopen the existing dataset when ``cfg.resume`` is set, otherwise create a new one.

    The image-writer thread count scales with the robot's number of cameras. A robot
    without a ``cameras`` attribute, or with ``cameras`` set to ``None``, counts as zero.
    """
    num_cameras = len(getattr(robot_raw, "cameras", None) or [])

    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.dataset.repo_id,
            root=cfg.dataset.root,
            batch_encoding_size=cfg.dataset.video_encoding_batch_size,
            vcodec=cfg.dataset.vcodec,
            streaming_encoding=cfg.dataset.streaming_encoding,
            encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
            encoder_threads=cfg.dataset.encoder_threads,
        )
        # A reopened dataset needs its image writer started explicitly; only do so when
        # there are cameras to write frames for.
        if num_cameras:
            dataset.start_image_writer(
                num_processes=cfg.dataset.num_image_writer_processes,
                num_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras,
            )
        return dataset

    return LeRobotDataset.create(
        cfg.dataset.repo_id,
        cfg.dataset.fps,
        root=cfg.dataset.root,
        robot_type=robot_raw.name,
        features=dataset_features,
        use_videos=cfg.dataset.video,
        image_writer_processes=cfg.dataset.num_image_writer_processes,
        image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras,
        batch_encoding_size=cfg.dataset.video_encoding_batch_size,
        vcodec=cfg.dataset.vcodec,
        streaming_encoding=cfg.dataset.streaming_encoding,
        encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
        encoder_threads=cfg.dataset.encoder_threads,
    )


def _load_hil_policy(cfg: HILConfig, dataset: LeRobotDataset, use_rtc: bool):
    """Load the policy and build its pre/post-processors.

    Returns ``(policy, preprocessor, postprocessor)``. In RTC mode the policy class is
    loaded directly from ``cfg.policy.pretrained_path`` so that ``predict_action_chunk``
    is supported. In that mode ``compile_model`` is set from ``cfg.use_torch_compile``
    (when the config has that field), ``cfg.rtc`` is attached, and the policy is moved to
    ``cfg.device`` in eval mode. In sync mode ``make_policy`` is used with the dataset
    metadata.
    """
    if use_rtc:
        policy_class = get_policy_class(cfg.policy.type)
        policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
        if hasattr(policy_config, "compile_model"):
            policy_config.compile_model = cfg.use_torch_compile
        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
        policy.config.rtc_config = cfg.rtc
        if hasattr(policy, "init_rtc_processor"):
            policy.init_rtc_processor()
        policy = policy.to(cfg.device)
        policy.eval()
    else:
        policy = make_policy(cfg.policy, ds_meta=dataset.meta)

    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
        dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
        preprocessor_overrides={
            "device_processor": {"device": cfg.device},
            "rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
        },
    )
    return policy, preprocessor, postprocessor


def _connect_hil_robot(robot_raw, cfg: HILConfig, use_rtc: bool) -> None:
    """Connect the robot hardware.

    Sync mode uses the robot's default ``connect()``. RTC mode always connects with
    ``calibrate=False``. When ``cfg.calibrate`` is set and the robot has a ``calibrate``
    method, it calibrates explicitly and then reconnects (again uncalibrated).
    """
    if not use_rtc:
        robot_raw.connect()
        return

    logger.info("Connecting robot (calibrate=%s)", cfg.calibrate)
    robot_raw.connect(calibrate=False)
    if cfg.calibrate and hasattr(robot_raw, "calibrate"):
        robot_raw.calibrate()
        robot_raw.disconnect()
        robot_raw.connect(calibrate=False)


@parser.wrap()
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
    """Run Human-in-the-Loop data collection and return the dataset recorded into.

    Main HIL data collection function (supports both sync and RTC modes).

    Steps:
      1. Build the dataset schema from the robot's features.
      2. Create the dataset, or reopen it when ``cfg.resume`` is set.
      3. Load the policy and its processors, then connect robot, teleop and keyboard.
      4. Record ``cfg.dataset.num_episodes`` episodes. Sync mode uses ``_rollout_sync``;
         when ``cfg.rtc.enabled`` is set, ``_rollout_rtc`` is used, fed by a background
         ``_rtc_inference_thread``.

    Cleanup always runs in ``finally``, even when an error occurs. It stops the RTC
    thread, finalizes the dataset, disconnects the hardware, stops the keyboard listener
    and optionally pushes to the hub.

    Args:
        cfg: HIL configuration, parsed from the command line by ``parser.wrap``.

    Returns:
        The ``LeRobotDataset`` that episodes were saved to.
    """
    init_logging()
    logger.info(pformat(cfg.__dict__))

    use_rtc = cfg.rtc.enabled
    if use_rtc:
        _set_openarm_max_relative_target_if_missing(cfg.robot, max_relative_target=8.0)

    if cfg.display_data:
        init_rerun(session_name="hil_collection")

    robot_raw = make_robot_from_config(cfg.robot)
    teleop = make_teleoperator_from_config(cfg.teleop)

    teleop_proc, obs_proc = make_identity_processors()
    dataset_features, observation_features_hw = _build_hil_dataset_features(
        robot_raw, cfg, teleop_proc, obs_proc
    )

    dataset = None
    listener = None
    shutdown_event = Event()
    policy_active = Event()
    compile_warmup_done = Event()
    # Without torch.compile there is no warmup to wait for, so mark it done up front.
    if not cfg.use_torch_compile:
        compile_warmup_done.set()
    rtc_thread = None
    rtc_join_timeout_s = 2.0

    try:
        dataset = _open_hil_dataset(cfg, robot_raw, dataset_features)

        # RTC needs manual policy loading for predict_action_chunk support.
        policy, preprocessor, postprocessor = _load_hil_policy(cfg, dataset, use_rtc)

        _connect_hil_robot(robot_raw, cfg, use_rtc)
        robot = ThreadSafeRobot(robot_raw) if use_rtc else robot_raw
        teleop.connect()
        listener, events = init_keyboard_listener()

        # RTC-specific setup. These stay None in sync mode.
        queue_holder = None
        obs_holder = None
        obs_lock = Lock()
        hw_features = None
        if use_rtc:
            _start_pedal_listener(events)
            queue_holder = {"queue": ActionQueue(cfg.rtc)}
            obs_holder = {
                "obs": None,
                "robot_type": robot.robot_type,
                "action_feature_names": [key for key in robot.action_features if key.endswith(".pos")],
            }
            hw_features = hw_to_dataset_features(observation_features_hw, "observation")

            rtc_thread = Thread(
                target=_rtc_inference_thread,
                args=(
                    policy,
                    obs_holder,
                    obs_lock,
                    hw_features,
                    preprocessor,
                    postprocessor,
                    queue_holder,
                    shutdown_event,
                    policy_active,
                    compile_warmup_done,
                    cfg,
                ),
                daemon=True,
            )
            rtc_thread.start()

        print_controls(rtc=use_rtc)
        logger.info(f" Policy: {cfg.policy.pretrained_path}")
        logger.info(f" Task: {cfg.dataset.single_task}")
        logger.info(f" Interpolation: {cfg.interpolation_multiplier}x")
        if use_rtc:
            logger.info(f" RTC: enabled (execution_horizon={cfg.rtc.execution_horizon})")

        with VideoEncodingManager(dataset):
            recorded = 0
            while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
                log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)

                if use_rtc:
                    # Start every episode with an empty action queue.
                    queue_holder["queue"] = ActionQueue(cfg.rtc)
                    _rollout_rtc(
                        robot=robot,
                        teleop=teleop,
                        policy=policy,
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        dataset=dataset,
                        events=events,
                        cfg=cfg,
                        queue_holder=queue_holder,
                        obs_holder=obs_holder,
                        obs_lock=obs_lock,
                        policy_active=policy_active,
                        compile_warmup_done=compile_warmup_done,
                        hw_features=hw_features,
                    )
                else:
                    _rollout_sync(
                        robot=robot,
                        teleop=teleop,
                        policy=policy,
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        dataset=dataset,
                        events=events,
                        cfg=cfg,
                    )

                if events["rerecord_episode"]:
                    log_say("Re-recording", cfg.play_sounds)
                    events["rerecord_episode"] = False
                    events["exit_early"] = False
                    dataset.clear_episode_buffer()
                    # The discarded attempt left the scene in its end state. Offer the
                    # same reset phase used between saved episodes before retrying.
                    if not events["stop_recording"]:
                        reset_loop(robot, teleop, events, cfg.dataset.fps)
                    continue

                dataset.save_episode()
                recorded += 1

                # No reset is needed after the final episode or once stop was requested.
                if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
                    reset_loop(robot, teleop, events, cfg.dataset.fps)

    finally:
        log_say("Stop recording", cfg.play_sounds, blocking=True)

        # Signal the RTC inference thread to stop before tearing down hardware.
        shutdown_event.set()
        policy_active.clear()

        if rtc_thread is not None and rtc_thread.is_alive():
            rtc_thread.join(timeout=rtc_join_timeout_s)
            if rtc_thread.is_alive():
                logger.warning(
                    "[RTC] Inference thread still running after %.1fs join timeout; continuing shutdown",
                    rtc_join_timeout_s,
                )

        # Use an explicit None check. Truthiness would go through __len__ when the
        # dataset defines it, so an empty dataset would skip finalize().
        if dataset is not None:
            dataset.finalize()

        if robot_raw.is_connected:
            robot_raw.disconnect()
        if teleop.is_connected:
            teleop.disconnect()

        if not is_headless() and listener:
            listener.stop()

        if cfg.dataset.push_to_hub and dataset is not None:
            dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)

    return dataset
|
|
|
|
|
|
def main():
    """Script entry point: register third-party plugins, then run ``hil_collect``."""
    from lerobot.utils.import_utils import register_third_party_plugins as _register_plugins

    # Register plugins first, presumably so plugin-provided robot/teleop/policy configs
    # are known when hil_collect parses the CLI. Confirm.
    _register_plugins()
    hil_collect()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run collection only when executed as a script, not when imported.
    main()
|