mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
818892a38b
* feat: HIL data collection, RTC interpolator, and action queue improvements - Add Human-in-the-Loop (HIL) data collection examples (sync + RTC) - Add HIL data collection documentation - Add ActionInterpolator for smoother policy control at higher rates - Integrate interpolator into lerobot-record and eval_with_real_robot - Add action queue clear() and get_processed_left_over() methods - Add rtc/__init__.py for cleaner imports * docs: expand Related Work section with paper summaries * fix: only record dataset frames at original fps, not at interpolated rate The interpolator speeds up robot control (e.g. 2x) but dataset frames should still be recorded at the original fps. Interpolated-only iterations now only send actions to the robot without writing to the dataset. * refactor: merge HIL sync and RTC scripts into single file with --rtc.enabled toggle Combines hil_data_collection.py and hil_data_collection_rtc.py into one script. RTC is toggled via --rtc.enabled=true (defaults to off for sync inference). Deletes the separate hil_data_collection_rtc.py and updates docs to reflect the single-script usage. * test: add ActionInterpolator test suite (29 tests) Covers constructor validation, passthrough (multiplier=1), 2x and 3x interpolation with exact value checks, reset/episode boundaries, control interval calculation, multi-dim actions, and simulated control loop integration. * test: add ActionQueue + ActionInterpolator integration tests Verifies the interpolator doesn't interfere with RTC's leftover chunk tracking: queue consumption rate matches base fps regardless of multiplier, get_left_over/get_processed_left_over only change on queue.get(), merge preserves smooth interpolation across chunks, and interpolator reset is independent of queue state. * feat: register SO follower/leader configs in HIL script Adds SOFollowerRobotConfig and SOLeaderTeleopConfig imports so SO100/SO101 robots can be used via --robot.type=so_follower and --teleop.type=so_leader. 
Updates docs accordingly. Made-with: Cursor * docs: remove em dashes from HIL documentation Made-with: Cursor * refactor: rename examples/rac to examples/hil Updates directory name and all references in docs and script docstrings. Made-with: Cursor * fix: encorperate pr feedback comments * refactor(tests): enhance ActionInterpolator test structure and add detailed docstrings * feedback pr and test fix * fix(test): pass correct real_delay in interpolator delay test The test was passing real_delay=0 and relying on _check_delays to silently override it with the index-based diff. Now passes real_delay=3 to match the 3 actions consumed during the simulated inference period. * fix pr feedback * ordering * update hil script * fix * default name * fix(bi_openarm): use kw_only=True to fix dataclass field ordering BiOpenArmFollowerConfig overrides `id` with a default, making it positional in the child — non-default `left_arm_config` then follows a default field, which Python dataclasses forbid. Adding kw_only=True (matching the parent RobotConfig) removes positional constraints. Made-with: Cursor * style: format long line in hil_data_collection.py Made-with: Cursor * pr feedback --------- Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
1184 lines
45 KiB
Python
1184 lines
45 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Human-in-the-Loop (HIL) Data Collection with optional Real-Time Chunking (RTC).
|
|
|
|
Implements the RaC paradigm (https://arxiv.org/abs/2509.07953) for LeRobot. By default uses synchronous
|
|
inference (best for fast models like ACT / Diffusion Policy). Set --rtc.enabled=true for
|
|
asynchronous background inference (recommended for large models like Pi0 / Pi0.5 / SmolVLA).
|
|
|
|
The workflow:
|
|
1. Policy runs autonomously
|
|
2. Press SPACE to pause - robot holds position
|
|
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
|
4. Press 'p' to hand control back to policy and continue recording
|
|
5. Press → to end episode (save and continue to next)
|
|
6. Reset, then do next rollout
|
|
|
|
Keyboard Controls:
|
|
SPACE - Pause policy (robot holds position, no recording)
|
|
c - Take control (start correction, recording resumes)
|
|
p - Resume policy after pause/correction (recording continues)
|
|
→ - End episode (save and continue to next)
|
|
← - Re-record episode
|
|
ESC - Stop recording and push dataset to hub
|
|
|
|
Usage:
|
|
# Standard synchronous inference (ACT, Diffusion Policy)
|
|
python examples/hil/hil_data_collection.py \
|
|
--robot.type=bi_openarm_follower \
|
|
--teleop.type=openarm_mini \
|
|
--policy.path=path/to/pretrained_model \
|
|
--dataset.repo_id=user/hil-dataset \
|
|
--dataset.single_task="Fold the T-shirt properly" \
|
|
--dataset.fps=30 \
|
|
--interpolation_multiplier=2
|
|
|
|
# With RTC for large models (Pi0, Pi0.5, SmolVLA)
|
|
python examples/hil/hil_data_collection.py \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--rtc.max_guidance_weight=5.0 \
|
|
--rtc.prefix_attention_schedule=LINEAR \
|
|
--robot.type=bi_openarm_follower \
|
|
--teleop.type=openarm_mini \
|
|
--policy.path=path/to/pretrained_model \
|
|
--dataset.repo_id=user/hil-dataset \
|
|
--dataset.single_task="Fold the T-shirt properly" \
|
|
--dataset.fps=30 \
|
|
--interpolation_multiplier=3
|
|
|
|
# RTC with bi_openarm_follower + OpenArm Mini teleop and pi0.5 policy
|
|
python examples/hil/hil_data_collection.py \
|
|
--policy.path=lerobot-data-collection/folding_final \
|
|
--robot.type=bi_openarm_follower \
|
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
|
--robot.left_arm_config.port=can0 \
|
|
--robot.left_arm_config.side=left \
|
|
--robot.left_arm_config.can_interface=socketcan \
|
|
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
|
--robot.left_arm_config.max_relative_target=8.0 \
|
|
--robot.right_arm_config.port=can1 \
|
|
--robot.right_arm_config.side=right \
|
|
--robot.right_arm_config.can_interface=socketcan \
|
|
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
|
--robot.right_arm_config.max_relative_target=8.0 \
|
|
--teleop.type=openarm_mini \
|
|
--teleop.port_left=/dev/ttyACM1 \
|
|
--teleop.port_right=/dev/ttyACM0 \
|
|
--dataset.repo_id=lerobot-data-collection/hil_folding \
|
|
--dataset.single_task="Fold the T-shirt properly" \
|
|
--dataset.fps=30 \
|
|
--dataset.num_episodes=50 \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--rtc.max_guidance_weight=5.0 \
|
|
--rtc.prefix_attention_schedule=LINEAR \
|
|
--interpolation_multiplier=3 \
|
|
--calibrate=true \
|
|
--device=cuda
|
|
"""
|
|
|
|
import logging
|
|
import math
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pprint import pformat
|
|
from threading import Event, Lock, Thread
|
|
from typing import Any
|
|
|
|
import torch
|
|
from hil_utils import (
|
|
HILDatasetConfig,
|
|
init_keyboard_listener,
|
|
make_identity_processors,
|
|
print_controls,
|
|
reset_loop,
|
|
teleop_disable_torque,
|
|
teleop_smooth_move_to,
|
|
)
|
|
|
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
|
from lerobot.datasets.image_writer import safe_stop_image_writer
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
|
from lerobot.policies.factory import get_policy_class, make_policy, make_pre_post_processors
|
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
|
from lerobot.policies.utils import make_robot_action
|
|
from lerobot.processor import (
|
|
NormalizerProcessorStep,
|
|
PolicyProcessorPipeline,
|
|
RelativeActionsProcessorStep,
|
|
TransitionKey,
|
|
create_transition,
|
|
)
|
|
from lerobot.processor.relative_action_processor import to_relative_actions
|
|
from lerobot.processor.rename_processor import rename_stats
|
|
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
|
from lerobot.robots.bi_openarm_follower.config_bi_openarm_follower import BiOpenArmFollowerConfig
|
|
from lerobot.robots.so_follower.config_so_follower import SOFollowerRobotConfig # noqa: F401
|
|
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
|
from lerobot.teleoperators.openarm_mini.config_openarm_mini import OpenArmMiniConfig # noqa: F401
|
|
from lerobot.teleoperators.so_leader.config_so_leader import SOLeaderTeleopConfig # noqa: F401
|
|
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
|
from lerobot.utils.control_utils import is_headless, predict_action
|
|
from lerobot.utils.device_utils import get_safe_torch_device
|
|
from lerobot.utils.robot_utils import precise_sleep
|
|
from lerobot.utils.utils import init_logging, log_say
|
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# RTC helpers
|
|
|
|
|
|
class ThreadSafeRobot:
    """Serialize robot hardware access behind a single lock.

    Used when RTC is enabled so the background inference thread and the main
    control loop never call into the robot driver concurrently. Read-only
    metadata properties pass through without locking.
    """

    def __init__(self, robot: Robot):
        self._robot = robot
        self._lock = Lock()

    def get_observation(self) -> dict[str, Any]:
        """Read an observation while holding the exclusive hardware lock."""
        with self._lock:
            return self._robot.get_observation()

    def send_action(self, action: dict) -> None:
        """Send an action while holding the exclusive hardware lock."""
        with self._lock:
            self._robot.send_action(action)

    # The properties below expose static metadata and do not touch hardware,
    # so they are deliberately unlocked.

    @property
    def observation_features(self) -> dict:
        return self._robot.observation_features

    @property
    def action_features(self) -> dict:
        return self._robot.action_features

    @property
    def name(self) -> str:
        return self._robot.name

    @property
    def robot_type(self) -> str:
        return self._robot.robot_type

    @property
    def cameras(self):
        # Robots without a `cameras` attribute report an empty mapping.
        return getattr(self._robot, "cameras", {})
|
|
|
|
|
|
def _set_openarm_max_relative_target_if_missing(
    robot_cfg: RobotConfig, max_relative_target: float = 8.0
) -> None:
    """Set a safe default max_relative_target for OpenArm followers when not provided.

    Only applies to ``BiOpenArmFollowerConfig``; any value the user already
    configured on either arm is left untouched.
    """
    if not isinstance(robot_cfg, BiOpenArmFollowerConfig):
        return
    for arm_cfg in (robot_cfg.left_arm_config, robot_cfg.right_arm_config):
        if arm_cfg.max_relative_target is None:
            arm_cfg.max_relative_target = max_relative_target
|
|
|
|
|
|
def _reanchor_relative_rtc_prefix(
    prev_actions_absolute: torch.Tensor,
    current_state: torch.Tensor,
    relative_step: RelativeActionsProcessorStep | None,
    normalizer_step: NormalizerProcessorStep | None,
    policy_device: torch.device | str,
) -> torch.Tensor:
    """Convert absolute leftovers into model space for relative-action RTC policies.

    When the policy predicts relative actions, leftover chunk actions stored in
    absolute robot space must be re-expressed relative to the *current* state
    (and re-normalized) before being fed back as the RTC prefix.
    """
    # Non-relative policies: leftovers are already in model space.
    if relative_step is None:
        return prev_actions_absolute.to(policy_device)

    # Work on CPU copies; ensure the state has a leading batch dimension.
    anchor_state = current_state.detach().cpu()
    if anchor_state.dim() == 1:
        anchor_state = anchor_state.unsqueeze(0)

    absolute_cpu = prev_actions_absolute.detach().cpu()
    joint_mask = relative_step._build_mask(absolute_cpu.shape[-1])
    rel = to_relative_actions(absolute_cpu, anchor_state, joint_mask)

    # Re-apply normalization so the prefix matches what the model was trained on.
    transition = create_transition(action=rel)
    if normalizer_step is not None:
        transition = normalizer_step(transition)

    return transition[TransitionKey.ACTION].to(policy_device)
|
|
|
|
|
|
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
|
"""Pad/truncate RTC prefix actions to a fixed length for stable compiled inference."""
|
|
if prev_actions.ndim != 2:
|
|
raise ValueError(f"Expected prev_actions to be 2D [T, A], got shape={tuple(prev_actions.shape)}")
|
|
|
|
steps, action_dim = prev_actions.shape
|
|
if steps == target_steps:
|
|
return prev_actions
|
|
if steps > target_steps:
|
|
return prev_actions[:target_steps]
|
|
|
|
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
|
padded[:steps] = prev_actions
|
|
return padded
|
|
|
|
|
|
def _resolve_action_key_order(cfg, dataset_action_names: list[str]) -> list[str]:
|
|
"""Choose action name ordering used to map policy tensor outputs to robot action dict."""
|
|
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
|
|
if not policy_action_names:
|
|
return dataset_action_names
|
|
|
|
policy_action_names = list(policy_action_names)
|
|
if len(policy_action_names) != len(dataset_action_names):
|
|
logger.warning(
|
|
"[RTC] policy.action_feature_names length (%d) != dataset action dim (%d); "
|
|
"falling back to dataset order",
|
|
len(policy_action_names),
|
|
len(dataset_action_names),
|
|
)
|
|
return dataset_action_names
|
|
|
|
if set(dataset_action_names) != set(policy_action_names):
|
|
logger.warning(
|
|
"[RTC] policy.action_feature_names keys do not match dataset action keys; "
|
|
"falling back to dataset order"
|
|
)
|
|
return dataset_action_names
|
|
|
|
return policy_action_names
|
|
|
|
|
|
def _resolve_state_joint_order(
|
|
policy_action_names: list[str] | None,
|
|
available_joint_names: list[str],
|
|
) -> list[str]:
|
|
"""Resolve joint-state ordering used to build observation.state."""
|
|
if not policy_action_names:
|
|
return available_joint_names
|
|
|
|
policy_action_names = list(policy_action_names)
|
|
available_set = set(available_joint_names)
|
|
policy_set = set(policy_action_names)
|
|
|
|
if len(policy_action_names) != len(available_joint_names) or policy_set != available_set:
|
|
logger.warning(
|
|
"policy.action_feature_names does not match available state joints; "
|
|
"falling back to robot observation order"
|
|
)
|
|
return available_joint_names
|
|
|
|
logger.info("Using policy.action_feature_names order for observation.state mapping")
|
|
return policy_action_names
|
|
|
|
|
|
def _start_pedal_listener(events: dict):
    """Start foot pedal listener thread if evdev is available.

    Pedal input is restricted to HIL control handoff only:
        policy -> pause -> takeover -> resume policy.
    Episode save/advance remains keyboard-only (right arrow).

    Fix: use the module-level ``logger`` consistently instead of mixing it with
    root-logger ``logging.*`` calls, and reuse the ``Thread`` name already
    imported at module level instead of a redundant local ``import threading``.
    """
    try:
        from evdev import InputDevice, categorize, ecodes
    except ImportError:
        logger.warning("[Pedal] evdev not installed - pedal support disabled")
        return

    # Stable by-id device path for the PCsensor USB foot switch.
    pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
    key_left = "KEY_A"
    key_right = "KEY_C"

    def pedal_reader():
        """Blocking read loop; runs in a daemon thread and mutates `events`."""
        try:
            dev = InputDevice(pedal_device)
            logger.info(f"[Pedal] Connected: {dev.name}")

            for ev in dev.read_loop():
                if ev.type != ecodes.EV_KEY:
                    continue

                key = categorize(ev)
                code = key.keycode
                # Some devices report a list of keycodes; take the first.
                if isinstance(code, (list, tuple)):
                    code = code[0]

                # React only to key-down (keystate 1), not release/hold.
                if key.keystate != 1:
                    continue

                if events["in_reset"]:
                    # During reset, any pedal press starts the next episode.
                    if code in [key_left, key_right]:
                        events["start_next_episode"] = True
                else:
                    if code not in [key_left, key_right]:
                        continue

                    # HIL handoff state machine:
                    # correcting -> resume policy; paused -> begin takeover;
                    # otherwise (policy running) -> pause policy.
                    if events["correction_active"]:
                        events["resume_policy"] = True
                    elif events["policy_paused"]:
                        events["start_next_episode"] = True
                    else:
                        events["policy_paused"] = True

        except FileNotFoundError:
            logger.info(f"[Pedal] Device not found: {pedal_device}")
        except PermissionError:
            logger.warning(f"[Pedal] Permission denied for {pedal_device}")
        except Exception as e:
            # Best-effort feature: never let a pedal failure kill the session.
            logger.warning(f"[Pedal] Error: {e}")

    thread = Thread(target=pedal_reader, daemon=True)
    thread.start()
|
|
|
|
|
|
def _rtc_inference_thread(
    policy: PreTrainedPolicy,
    obs_holder: dict,            # shared with the control loop; guarded by obs_lock
    obs_lock: Lock,
    hw_features: dict,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    queue_holder: dict,          # holds the shared ActionQueue under key "queue"
    shutdown_event: Event,
    policy_active: Event,        # cleared while human has control / policy paused
    compile_warmup_done: Event,
    cfg,
):
    """Background thread for RTC action chunk generation.

    Loops until ``shutdown_event`` is set: whenever the shared action queue runs
    low, it reads the latest observation (under ``obs_lock``), runs policy
    inference with the leftover-chunk prefix, and merges the new chunk into the
    queue. Latency is tracked so the merge can account for actions the control
    loop consumed during inference.
    """
    latency_tracker = LatencyTracker()
    # NOTE(review): despite the name, this is the period of a single action
    # step (1/fps), used to convert inference latency into a step-count delay.
    time_per_chunk = 1.0 / cfg.dataset.fps
    # Refill the queue once it drops to this many remaining actions.
    threshold = 30
    policy_device = policy.config.device
    stats_window_start = time.perf_counter()
    policy_inference_count = 0
    latency_sum_s = 0.0
    inference_count = 0
    # With torch.compile, the first inference(s) include compilation time and
    # must not pollute the latency statistics.
    warmup_required = max(1, int(cfg.compile_warmup_inferences)) if cfg.use_torch_compile else 0

    # Locate the (optional) relative-action and normalizer steps once, so the
    # hot loop can re-anchor leftover actions without scanning the pipeline.
    relative_step = next(
        (
            step
            for step in preprocessor.steps
            if isinstance(step, RelativeActionsProcessorStep) and step.enabled
        ),
        None,
    )
    normalizer_step = next(
        (step for step in preprocessor.steps if isinstance(step, NormalizerProcessorStep)),
        None,
    )
    if relative_step is not None:
        # Backfill action names on the relative step: prefer the policy config,
        # fall back to whatever the control loop published in obs_holder.
        if relative_step.action_names is None:
            cfg_action_names = getattr(cfg.policy, "action_feature_names", None)
            if cfg_action_names:
                relative_step.action_names = list(cfg_action_names)
            else:
                fallback_action_names = obs_holder.get("action_feature_names")
                if fallback_action_names:
                    relative_step.action_names = list(fallback_action_names)
        logger.info("[RTC] Relative actions enabled: re-anchoring RTC prefix to current state")

    while not shutdown_event.is_set():
        # Idle (cheaply) while the human has control or the policy is paused.
        if not policy_active.is_set():
            time.sleep(0.01)
            continue

        queue = queue_holder.get("queue")
        with obs_lock:
            obs = obs_holder.get("obs")
        if queue is None or obs is None:
            time.sleep(0.01)
            continue

        if queue.qsize() <= threshold:
            try:
                current_time = time.perf_counter()
                # Snapshot the consumption index before inference so merge()
                # can tell how many actions were executed meanwhile.
                idx_before = queue.get_action_index()
                prev_actions = queue.get_left_over()

                # Estimate how many steps the robot will advance during this
                # inference, based on the worst latency observed so far.
                latency = latency_tracker.max()
                delay = math.ceil(latency / time_per_chunk) if latency else 0

                # Build the policy input batch: to-tensor, image HWC->CHW in
                # [0, 1], then add the batch dimension and move to device.
                obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
                for name in obs_batch:
                    obs_batch[name] = torch.from_numpy(obs_batch[name])
                    if "image" in name:
                        obs_batch[name] = obs_batch[name].float() / 255
                        obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
                    obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device)

                obs_batch["task"] = [cfg.dataset.single_task]
                obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")

                preprocessed = preprocessor(obs_batch)

                # Relative-action policies: leftovers are stored in absolute
                # space and must be re-expressed against the current state.
                if prev_actions is not None and relative_step is not None and OBS_STATE in obs_batch:
                    prev_actions_absolute = queue.get_processed_left_over()
                    if prev_actions_absolute is not None and prev_actions_absolute.numel() > 0:
                        prev_actions = _reanchor_relative_rtc_prefix(
                            prev_actions_absolute=prev_actions_absolute,
                            current_state=obs_batch[OBS_STATE],
                            relative_step=relative_step,
                            normalizer_step=normalizer_step,
                            policy_device=policy_device,
                        )

                # Fixed prefix length keeps compiled graphs shape-stable.
                if prev_actions is not None:
                    prev_actions = _normalize_prev_actions_length(
                        prev_actions, target_steps=cfg.rtc.execution_horizon
                    )

                actions = policy.predict_action_chunk(
                    preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
                )

                # Keep both raw (model-space) and postprocessed (robot-space)
                # chunks: the queue needs raw values for the next RTC prefix.
                original = actions.squeeze(0).clone()
                processed = postprocessor(actions).squeeze(0)
                new_latency = time.perf_counter() - current_time
                new_delay = math.ceil(new_latency / time_per_chunk)
                inference_count += 1
                is_warmup_inference = cfg.use_torch_compile and inference_count <= warmup_required
                if is_warmup_inference:
                    # Compilation-time latency is not representative; drop it.
                    latency_tracker.reset()
                else:
                    latency_tracker.add(new_latency)
                queue.merge(original, processed, new_delay, idx_before)
                policy_inference_count += 1
                latency_sum_s += new_latency
                if (
                    is_warmup_inference
                    and inference_count >= warmup_required
                    and not compile_warmup_done.is_set()
                ):
                    compile_warmup_done.set()
                    logger.info(
                        "[RTC] Compile warmup complete (%d/%d inferences)",
                        inference_count,
                        warmup_required,
                    )
                logger.debug("[RTC] Inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
            except Exception as e:
                # Keep the thread alive on transient errors; back off briefly.
                logger.error("[RTC] Error: %s", e)
                time.sleep(0.5)
        else:
            time.sleep(0.01)

        # Periodic rate/latency reporting over a rolling window.
        now = time.perf_counter()
        if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s:
            policy_hz = policy_inference_count / window_elapsed
            avg_latency_ms = (
                (latency_sum_s / policy_inference_count * 1000.0) if policy_inference_count else 0.0
            )
            logger.info(
                "[HIL RTC rates] policy=%.1f Hz | avg_inference=%.1f ms | queue=%d",
                policy_hz,
                avg_latency_ms,
                queue.qsize(),
            )
            stats_window_start = now
            policy_inference_count = 0
            latency_sum_s = 0.0
|
|
|
|
|
|
# Config
|
|
|
|
|
|
@dataclass
class HILConfig:
    """Top-level configuration for the HIL data-collection script.

    Parsed from the CLI by lerobot's config parser; the policy sub-config is
    loaded from the ``--policy.path`` argument in ``__post_init__``.
    """

    robot: RobotConfig
    teleop: TeleoperatorConfig
    dataset: HILDatasetConfig
    # Loaded from --policy.path in __post_init__; required (validated there).
    policy: PreTrainedConfig | None = None
    # RTC defaults to disabled, i.e. synchronous inference.
    rtc: RTCConfig = field(default_factory=RTCConfig)
    # Robot control runs at dataset.fps * interpolation_multiplier.
    interpolation_multiplier: int = 2
    # If False, dataset frames are recorded only at the base fps.
    record_interpolated_actions: bool = False
    display_data: bool = True
    play_sounds: bool = True
    resume: bool = False
    device: str = "cuda"
    use_torch_compile: bool = False
    # Number of inferences treated as compile warmup (latency stats ignored).
    compile_warmup_inferences: int = 2
    calibrate: bool = False
    # Periodic Hz/latency logging and its reporting interval.
    log_hz: bool = True
    hz_log_interval_s: float = 2.0

    def __post_init__(self):
        """Load the pretrained policy config from --policy.path.

        Raises:
            ValueError: if no policy path was provided.
        """
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # CLI overrides (e.g. --policy.device=...) still apply on top of
            # the pretrained config.
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        if self.policy is None:
            raise ValueError("policy.path is required")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        # Parser hook: fields whose value may be given as a path argument.
        return ["policy"]
|
|
|
|
|
|
# Rollout loops
|
|
|
|
|
|
@safe_stop_image_writer
def _rollout_sync(
    robot: Robot,
    teleop: Teleoperator,
    policy: PreTrainedPolicy,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    dataset: LeRobotDataset,
    events: dict,                 # mutated in place by keyboard/pedal listeners
    cfg: HILConfig,
):
    """Rollout loop with standard synchronous inference.

    Runs one episode at ``fps * interpolation_multiplier`` control rate,
    alternating between policy control and human correction based on the
    shared ``events`` flags. Frames are recorded only every ``record_stride``
    control ticks (i.e. at the base dataset fps unless
    ``record_interpolated_actions`` is set).
    """
    fps = cfg.dataset.fps
    device = get_safe_torch_device(cfg.device)
    stream_online = bool(cfg.dataset.streaming_encoding)
    # Record every tick if interpolated actions are wanted, else every
    # multiplier-th tick so the dataset stays at the original fps.
    record_stride = 1 if cfg.record_interpolated_actions else max(1, cfg.interpolation_multiplier)

    policy.reset()
    preprocessor.reset()
    postprocessor.reset()

    # When not streaming-encoding, frames are buffered and written at the end.
    frame_buffer: list[dict] = []
    teleop_disable_torque(teleop)

    was_paused = False
    waiting_for_takeover = False
    # Last action sent to the robot; replayed to hold position while paused.
    last_action: dict[str, Any] | None = None
    robot_action: dict[str, Any] = {}
    action_keys = list(dataset.features[ACTION]["names"])
    obs_state_names = list(dataset.features[f"{OBS_STR}.state"]["names"])
    obs_image_names = [
        key.removeprefix(f"{OBS_STR}.images.")
        for key in dataset.features
        if key.startswith(f"{OBS_STR}.images.")
    ]

    interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
    # Control period = 1 / (fps * multiplier); see ActionInterpolator.
    control_interval = interpolator.get_control_interval(fps)

    timestamp = 0.0
    record_tick = 0
    start_t = time.perf_counter()
    stats_window_start = start_t
    policy_inference_count = 0
    robot_command_count = 0

    while timestamp < cfg.dataset.episode_time_s:
        loop_start = time.perf_counter()

        # End of episode requested (arrow keys / ESC): clear HIL flags and exit.
        if events["exit_early"]:
            events["exit_early"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            events["resume_policy"] = False
            break

        # 'p' pressed: hand control back to the policy from pause/correction.
        if events["resume_policy"] and (
            events["policy_paused"] or events["correction_active"] or waiting_for_takeover
        ):
            events["resume_policy"] = False
            events["start_next_episode"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            waiting_for_takeover = False
            was_paused = False
            last_action = None
            # Fresh policy/pipeline state so stale action chunks are dropped.
            interpolator.reset()
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()

        # SPACE pressed: first pause tick — align the teleop device with the
        # robot's current pose so a later takeover does not jump.
        if events["policy_paused"] and not was_paused:
            obs = robot.get_observation()
            robot_pos = {
                k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
            }
            teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
            events["start_next_episode"] = False
            waiting_for_takeover = True
            was_paused = True
            interpolator.reset()

        # 'c' pressed while paused: human takes over; free the teleop device.
        if waiting_for_takeover and events["start_next_episode"]:
            teleop_disable_torque(teleop)
            events["start_next_episode"] = False
            events["correction_active"] = True
            waiting_for_takeover = False

        # Fresh observation for this tick, filtered to dataset features.
        obs = robot.get_observation()
        obs_filtered = {k: obs[k] for k in obs_state_names if k in obs}
        obs_filtered.update({k: obs[k] for k in obs_image_names if k in obs})
        obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)

        if events["correction_active"]:
            # Human correction: mirror teleop to robot and record the frames.
            robot_action = teleop.get_action()
            robot.send_action(robot_action)
            robot_command_count += 1
            action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
            if record_tick % record_stride == 0:
                frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                if stream_online:
                    dataset.add_frame(frame)
                else:
                    frame_buffer.append(frame)
            record_tick += 1

        elif waiting_for_takeover or events["policy_paused"]:
            # Paused: hold position by replaying the last action; no recording.
            if last_action:
                robot.send_action(last_action)
                robot_command_count += 1

        else:
            # Autonomous: run inference only when the interpolator is drained.
            if interpolator.needs_new_action():
                action_values = predict_action(
                    observation=obs_frame,
                    policy=policy,
                    device=device,
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                    use_amp=policy.config.use_amp,
                    task=cfg.dataset.single_task,
                    robot_type=robot.robot_type,
                )
                policy_inference_count += 1
                robot_action = make_robot_action(action_values, dataset.features)
                action_tensor = torch.tensor([robot_action[k] for k in action_keys])
                interpolator.add(action_tensor)

            # Interpolated step toward the latest policy action.
            interp_action = interpolator.get()
            if interp_action is not None:
                robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
                robot.send_action(robot_action)
                robot_command_count += 1
                last_action = robot_action
                action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
                if record_tick % record_stride == 0:
                    frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                    if stream_online:
                        dataset.add_frame(frame)
                    else:
                        frame_buffer.append(frame)
                record_tick += 1

        if cfg.display_data and robot_action:
            log_rerun_data(observation=obs_filtered, action=robot_action)

        # Pace the loop to the interpolated control rate.
        dt = time.perf_counter() - loop_start
        if (sleep_time := control_interval - dt) > 0:
            precise_sleep(sleep_time)
        now = time.perf_counter()
        timestamp = now - start_t

        # Periodic rate reporting over a rolling window.
        if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s:
            policy_hz = policy_inference_count / window_elapsed
            robot_hz = robot_command_count / window_elapsed
            logger.info(
                "[HIL rates] policy=%.1f Hz (target=%.1f) | robot=%.1f Hz (target=%.1f)",
                policy_hz,
                fps,
                robot_hz,
                fps * cfg.interpolation_multiplier,
            )
            stats_window_start = now
            policy_inference_count = 0
            robot_command_count = 0

    teleop_disable_torque(teleop)

    # Deferred write path when streaming encoding is off.
    if not stream_online:
        for frame in frame_buffer:
            dataset.add_frame(frame)
|
|
|
|
|
|
@safe_stop_image_writer
def _rollout_rtc(
    robot,
    teleop: Teleoperator,
    policy: PreTrainedPolicy,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    dataset: LeRobotDataset,
    events: dict,
    cfg: HILConfig,
    queue_holder: dict,
    obs_holder: dict,
    obs_lock: Lock,
    policy_active: Event,
    compile_warmup_done: Event,
    hw_features: dict,
):
    """Rollout loop with RTC for asynchronous inference.

    Drives one episode of the control loop while a background inference
    thread (started in ``hil_collect``) fills ``queue_holder["queue"]`` with
    action chunks. Each iteration handles, in order: operator events
    (exit / resume / pause), human-takeover transitions, observation polling,
    then exactly one of three action modes:

    * correction active: teleop actions are sent to the robot and recorded;
    * paused / waiting for takeover: the last sent action is re-sent so the
      robot holds position;
    * live policy: actions popped from the RTC queue are interpolated and
      sent at the (higher) control rate.

    Dataset frames are written every ``record_stride`` control ticks so the
    stored rate stays at the base dataset fps unless
    ``cfg.record_interpolated_actions`` is set.

    Args:
        robot: Connected robot (wrapped in ``ThreadSafeRobot`` by the caller
            in RTC mode, so it is shared safely with the inference thread).
        teleop: Teleoperator used for human corrections; torque is disabled
            here whenever the human should be able to move it freely.
        policy: Loaded policy. Only ``reset()`` is called in this loop —
            inference itself runs in the background thread.
        preprocessor: Observation pipeline; reset at episode/resume points.
        postprocessor: Action pipeline; reset at episode/resume points.
        dataset: Destination dataset. Frames are added live when streaming
            encoding is enabled, otherwise buffered and flushed at episode end.
        events: Mutable flag dict from the keyboard/pedal listeners
            (``exit_early``, ``resume_policy``, ``policy_paused``,
            ``correction_active``, ``start_next_episode``).
        cfg: Full HIL configuration.
        queue_holder: Single-slot dict sharing the ``ActionQueue`` with the
            inference thread; replaced wholesale to flush pending chunks.
        obs_holder: Single-slot dict publishing the latest filtered
            observation to the inference thread (written under ``obs_lock``).
        obs_lock: Lock guarding ``obs_holder`` updates.
        policy_active: Event signalling the inference thread to run.
        compile_warmup_done: Event set once torch.compile warmup completes.
        hw_features: NOTE(review): accepted but never used in this body —
            presumably kept for signature symmetry with the inference-thread
            setup; confirm before removing.
    """
    fps = cfg.dataset.fps
    stream_online = bool(cfg.dataset.streaming_encoding)
    # Record every tick when interpolated actions are wanted; otherwise only
    # every `interpolation_multiplier`-th tick, i.e. at the base fps.
    record_stride = 1 if cfg.record_interpolated_actions else max(1, cfg.interpolation_multiplier)

    policy.reset()
    preprocessor.reset()
    postprocessor.reset()

    frame_buffer: list[dict] = []
    teleop_disable_torque(teleop)

    was_paused = False
    waiting_for_takeover = False
    last_action: dict[str, Any] | None = None
    dataset_action_keys = list(dataset.features[ACTION]["names"])
    # Action tensors index by position, so the key order must match whatever
    # ordering the policy was trained with (may differ from the dataset's).
    action_keys = _resolve_action_key_order(cfg, dataset_action_keys)
    if action_keys != dataset_action_keys:
        logger.info("[RTC] Using policy.action_feature_names order for action tensor mapping")
    else:
        logger.info("[RTC] Using dataset action feature order for action tensor mapping")
    obs_state_names = list(dataset.features[f"{OBS_STR}.state"]["names"])
    obs_image_names = [
        key.removeprefix(f"{OBS_STR}.images.")
        for key in dataset.features
        if key.startswith(f"{OBS_STR}.images.")
    ]

    interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
    # Loop period — presumably 1 / (fps * multiplier); see ActionInterpolator.
    control_interval = interpolator.get_control_interval(fps)

    robot_action: dict[str, Any] = {}
    timestamp = 0.0
    start_t = time.perf_counter()
    stats_window_start = start_t
    robot_command_count = 0
    record_tick = 0
    # Observations are polled at the base fps even though the loop runs faster.
    obs_poll_interval = 1.0 / fps
    last_obs_poll_t = 0.0
    obs_filtered: dict[str, Any] = {}
    obs_frame: dict[str, Any] = {}
    warmup_wait_logged = False
    warmup_queue_flushed = False

    while timestamp < cfg.dataset.episode_time_s:
        loop_start = time.perf_counter()

        # Operator requested early episode end: clear all mode flags and stop.
        if events["exit_early"]:
            events["exit_early"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            events["resume_policy"] = False
            break

        # Resume autonomous control after a pause/correction: reset policy
        # state, drop any stale queued chunks, and clear takeover bookkeeping.
        if events["resume_policy"] and (
            events["policy_paused"] or events["correction_active"] or waiting_for_takeover
        ):
            events["resume_policy"] = False
            events["start_next_episode"] = False
            events["policy_paused"] = False
            events["correction_active"] = False
            waiting_for_takeover = False
            was_paused = False
            last_action = None
            interpolator.reset()
            # Fresh queue discards chunks computed before the pause.
            queue_holder["queue"] = ActionQueue(cfg.rtc)
            policy_active.clear()
            policy.reset()
            preprocessor.reset()
            postprocessor.reset()

        # Transition into "paused": stop inference, mirror the robot pose onto
        # the leader arm so the human can take over smoothly.
        if events["policy_paused"] and not was_paused:
            policy_active.clear()
            obs = robot.get_observation()
            robot_pos = {
                k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
            }
            teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
            events["start_next_episode"] = False
            waiting_for_takeover = True
            was_paused = True
            interpolator.reset()

        # Human confirmed takeover: free the leader arm and switch to
        # correction mode; flush the queue so stale chunks are not replayed.
        if waiting_for_takeover and events["start_next_episode"]:
            teleop_disable_torque(teleop)
            events["start_next_episode"] = False
            events["correction_active"] = True
            waiting_for_takeover = False
            queue_holder["queue"] = ActionQueue(cfg.rtc)

        # Poll observations at base fps (or every tick while a human is in
        # the loop) and publish them to the inference thread.
        now_for_obs = time.perf_counter()
        should_poll_obs = (
            not obs_filtered
            or (now_for_obs - last_obs_poll_t) >= obs_poll_interval
            or events["correction_active"]
            or waiting_for_takeover
            or events["policy_paused"]
        )
        if should_poll_obs:
            obs = robot.get_observation()
            obs_filtered = {k: obs[k] for k in obs_state_names if k in obs}
            obs_filtered.update({k: obs[k] for k in obs_image_names if k in obs})
            obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
            with obs_lock:
                obs_holder["obs"] = obs_filtered
            last_obs_poll_t = now_for_obs

        # --- Mode 1: human correction — teleop drives, frames are recorded.
        if events["correction_active"]:
            robot_action = teleop.get_action()
            robot.send_action(robot_action)
            robot_command_count += 1
            action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
            if record_tick % record_stride == 0:
                frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                if stream_online:
                    dataset.add_frame(frame)
                else:
                    frame_buffer.append(frame)
            record_tick += 1

        # --- Mode 2: paused — hold position by re-sending the last action
        # (nothing is recorded while paused).
        elif waiting_for_takeover or events["policy_paused"]:
            if last_action:
                robot.send_action(last_action)
                robot_command_count += 1

        # --- Mode 3: live policy rollout via the RTC queue.
        else:
            if not policy_active.is_set():
                policy_active.set()

            if cfg.use_torch_compile and not compile_warmup_done.is_set():
                # Inference thread is still warming up torch.compile; don't
                # act on its outputs yet. Log the wait only once.
                if not warmup_wait_logged:
                    logger.info(
                        "[RTC] Waiting for compile warmup (%d inferences) before policy rollout",
                        max(1, int(cfg.compile_warmup_inferences)),
                    )
                    warmup_wait_logged = True
            else:
                # One-time flush after warmup so warmup-era chunks are dropped.
                if cfg.use_torch_compile and not warmup_queue_flushed:
                    queue_holder["queue"] = ActionQueue(cfg.rtc)
                    interpolator.reset()
                    warmup_queue_flushed = True
                    logger.info("[RTC] Warmup queue cleared; starting live policy rollout")

                # Re-read each tick: the holder's queue may have been replaced.
                queue = queue_holder["queue"]

                if interpolator.needs_new_action():
                    new_action = queue.get() if queue else None
                    if new_action is not None:
                        interpolator.add(new_action.cpu())

                action_tensor = interpolator.get()
                if action_tensor is not None:
                    # Guard `i < len(...)` in case the tensor has fewer dims
                    # than the configured action keys.
                    robot_action = {
                        k: action_tensor[i].item()
                        for i, k in enumerate(action_keys)
                        if i < len(action_tensor)
                    }
                    robot.send_action(robot_action)
                    robot_command_count += 1
                    last_action = robot_action
                    action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
                    if record_tick % record_stride == 0:
                        frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
                        if stream_online:
                            dataset.add_frame(frame)
                        else:
                            frame_buffer.append(frame)
                    record_tick += 1

        # Keep the loop period at `control_interval` regardless of work done.
        dt = time.perf_counter() - loop_start
        if (sleep_time := control_interval - dt) > 0:
            precise_sleep(sleep_time)
        now = time.perf_counter()
        timestamp = now - start_t

        # Periodic achieved-rate report (target is fps * multiplier).
        if cfg.log_hz and (window_elapsed := now - stats_window_start) >= cfg.hz_log_interval_s:
            robot_hz = robot_command_count / window_elapsed
            logger.info(
                "[HIL RTC rates] robot=%.1f Hz (target=%.1f)",
                robot_hz,
                fps * cfg.interpolation_multiplier,
            )
            stats_window_start = now
            robot_command_count = 0

    # Episode over: stop inference and free the leader arm.
    policy_active.clear()
    teleop_disable_torque(teleop)

    # Flush buffered frames when not streaming to the encoder during rollout.
    if not stream_online:
        for frame in frame_buffer:
            dataset.add_frame(frame)
|
|
|
|
|
|
# Main collection entry point (handles both sync and RTC modes)
|
|
|
|
|
|
@parser.wrap()
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
    """Main HIL data collection function (supports both sync and RTC modes).

    Builds the dataset schema from the robot's hardware features, loads the
    policy (manually in RTC mode so chunked inference can be configured),
    connects the robot and teleoperator, optionally spawns the RTC inference
    thread, then records ``cfg.dataset.num_episodes`` episodes via
    ``_rollout_rtc`` or ``_rollout_sync``. Hardware and threads are torn down
    in the ``finally`` block even on error.

    Args:
        cfg: Parsed HIL configuration (``@parser.wrap`` fills it from CLI).

    Returns:
        The recorded (finalized, optionally hub-pushed) ``LeRobotDataset``.
    """
    init_logging()
    logger.info(pformat(cfg.__dict__))

    use_rtc = cfg.rtc.enabled

    if use_rtc:
        # Safety clamp for OpenArm-style robots when the user didn't set one
        # (see helper); only relevant for the faster RTC control rate.
        _set_openarm_max_relative_target_if_missing(cfg.robot, max_relative_target=8.0)

    if cfg.display_data:
        init_rerun(session_name="hil_collection")

    robot_raw = make_robot_from_config(cfg.robot)
    teleop = make_teleoperator_from_config(cfg.teleop)

    teleop_proc, obs_proc = make_identity_processors()

    # Action features: joint position commands only.
    action_features_hw = {k: v for k, v in robot_raw.action_features.items() if k.endswith(".pos")}
    all_observation_features = robot_raw.observation_features
    # Scalar float joint readings (feature value declared as the `float` type).
    available_joint_names = [
        key for key, value in all_observation_features.items() if key.endswith(".pos") and value is float
    ]
    # Order state joints to match the policy's expected ordering when known.
    ordered_joint_names = _resolve_state_joint_order(
        getattr(cfg.policy, "action_feature_names", None),
        available_joint_names,
    )
    observation_features_hw = {
        joint_name: all_observation_features[joint_name] for joint_name in ordered_joint_names
    }
    # Tuple-valued features (camera image shapes) are appended after joints.
    for key, value in all_observation_features.items():
        if isinstance(value, tuple):
            observation_features_hw[key] = value

    dataset_features = combine_feature_dicts(
        aggregate_pipeline_dataset_features(
            pipeline=teleop_proc,
            initial_features=create_initial_features(action=action_features_hw),
            use_videos=cfg.dataset.video,
        ),
        aggregate_pipeline_dataset_features(
            pipeline=obs_proc,
            initial_features=create_initial_features(observation=observation_features_hw),
            use_videos=cfg.dataset.video,
        ),
    )

    dataset = None
    listener = None
    shutdown_event = Event()
    policy_active = Event()
    compile_warmup_done = Event()
    if not cfg.use_torch_compile:
        # No compile step → warmup is trivially "done".
        compile_warmup_done.set()
    rtc_thread = None

    try:
        if cfg.resume:
            # Append to an existing dataset on disk.
            dataset = LeRobotDataset(
                cfg.dataset.repo_id,
                root=cfg.dataset.root,
                batch_encoding_size=cfg.dataset.video_encoding_batch_size,
                vcodec=cfg.dataset.vcodec,
                streaming_encoding=cfg.dataset.streaming_encoding,
                encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
                encoder_threads=cfg.dataset.encoder_threads,
            )
            if hasattr(robot_raw, "cameras") and robot_raw.cameras:
                dataset.start_image_writer(
                    num_processes=cfg.dataset.num_image_writer_processes,
                    num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
                )
        else:
            dataset = LeRobotDataset.create(
                cfg.dataset.repo_id,
                cfg.dataset.fps,
                root=cfg.dataset.root,
                robot_type=robot_raw.name,
                features=dataset_features,
                use_videos=cfg.dataset.video,
                image_writer_processes=cfg.dataset.num_image_writer_processes,
                image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
                * len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
                batch_encoding_size=cfg.dataset.video_encoding_batch_size,
                vcodec=cfg.dataset.vcodec,
                streaming_encoding=cfg.dataset.streaming_encoding,
                encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
                encoder_threads=cfg.dataset.encoder_threads,
            )

        # Load policy — RTC needs manual loading for predict_action_chunk support
        if use_rtc:
            policy_class = get_policy_class(cfg.policy.type)
            policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
            if hasattr(policy_config, "compile_model"):
                policy_config.compile_model = cfg.use_torch_compile
            policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
            policy.config.rtc_config = cfg.rtc
            if hasattr(policy, "init_rtc_processor"):
                policy.init_rtc_processor()
            policy = policy.to(cfg.device)
            policy.eval()
        else:
            policy = make_policy(cfg.policy, ds_meta=dataset.meta)

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
            preprocessor_overrides={
                "device_processor": {"device": cfg.device},
                "rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
            },
        )

        # Connect hardware
        if use_rtc:
            logger.info("Connecting robot (calibrate=%s)", cfg.calibrate)
            robot_raw.connect(calibrate=False)
            if cfg.calibrate and hasattr(robot_raw, "calibrate"):
                # Calibrate, then reconnect cleanly without re-calibrating.
                robot_raw.calibrate()
                robot_raw.disconnect()
                robot_raw.connect(calibrate=False)
        else:
            robot_raw.connect()

        # RTC shares the robot with the inference thread → thread-safe wrapper.
        robot = ThreadSafeRobot(robot_raw) if use_rtc else robot_raw
        teleop.connect()
        listener, events = init_keyboard_listener()

        # RTC-specific setup
        queue_holder = None
        obs_holder = None
        obs_lock = Lock()
        hw_features = None
        if use_rtc:
            _start_pedal_listener(events)
            queue_holder = {"queue": ActionQueue(cfg.rtc)}
            obs_holder = {
                "obs": None,
                "robot_type": robot.robot_type,
                "action_feature_names": [key for key in robot.action_features if key.endswith(".pos")],
            }
            hw_features = hw_to_dataset_features(observation_features_hw, "observation")

            # Background inference thread; daemon so it can't block shutdown.
            rtc_thread = Thread(
                target=_rtc_inference_thread,
                args=(
                    policy,
                    obs_holder,
                    obs_lock,
                    hw_features,
                    preprocessor,
                    postprocessor,
                    queue_holder,
                    shutdown_event,
                    policy_active,
                    compile_warmup_done,
                    cfg,
                ),
                daemon=True,
            )
            rtc_thread.start()

        print_controls(rtc=use_rtc)
        logger.info(f" Policy: {cfg.policy.pretrained_path}")
        logger.info(f" Task: {cfg.dataset.single_task}")
        logger.info(f" Interpolation: {cfg.interpolation_multiplier}x")
        if use_rtc:
            logger.info(f" RTC: enabled (execution_horizon={cfg.rtc.execution_horizon})")

        with VideoEncodingManager(dataset):
            recorded = 0
            while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
                log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)

                if use_rtc:
                    # Fresh queue per episode so no stale chunks carry over.
                    queue_holder["queue"] = ActionQueue(cfg.rtc)
                    _rollout_rtc(
                        robot=robot,
                        teleop=teleop,
                        policy=policy,
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        dataset=dataset,
                        events=events,
                        cfg=cfg,
                        queue_holder=queue_holder,
                        obs_holder=obs_holder,
                        obs_lock=obs_lock,
                        policy_active=policy_active,
                        compile_warmup_done=compile_warmup_done,
                        hw_features=hw_features,
                    )
                else:
                    _rollout_sync(
                        robot=robot,
                        teleop=teleop,
                        policy=policy,
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        dataset=dataset,
                        events=events,
                        cfg=cfg,
                    )

                # Operator asked to redo the episode: drop the buffer, retry.
                if events["rerecord_episode"]:
                    log_say("Re-recording", cfg.play_sounds)
                    events["rerecord_episode"] = False
                    events["exit_early"] = False
                    dataset.clear_episode_buffer()
                    continue

                dataset.save_episode()
                recorded += 1

                if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
                    reset_loop(robot, teleop, events, cfg.dataset.fps)

    finally:
        log_say("Stop recording", cfg.play_sounds, blocking=True)

        # Stop the inference thread before touching shared hardware state.
        shutdown_event.set()
        policy_active.clear()

        if rtc_thread and rtc_thread.is_alive():
            rtc_thread.join(timeout=2.0)

        if dataset:
            dataset.finalize()

        if robot_raw.is_connected:
            robot_raw.disconnect()
        if teleop.is_connected:
            teleop.disconnect()

        if not is_headless() and listener:
            listener.stop()

    # Only reached on clean completion (exceptions propagate past finally).
    if cfg.dataset.push_to_hub and dataset is not None:
        dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)

    return dataset
|
|
|
|
|
|
def main():
    """CLI entry point: register third-party hardware plugins, then collect."""
    # NOTE(review): registration runs before hil_collect() — presumably so
    # plugin-provided robot/teleop types are resolvable during config parsing.
    from lerobot.utils.import_utils import register_third_party_plugins

    register_third_party_plugins()
    hil_collect()
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|