Compare commits

...

24 Commits

Author SHA1 Message Date
Steven Palma 8e21268c29 test: add dataset guard + fix imports 2026-04-20 00:36:02 +02:00
Steven Palma 4130d4a4a5 update docs + docstrings + examples + add minimal test 2026-04-19 23:53:53 +02:00
Steven Palma 47bb840a55 add context guards 2026-04-19 23:21:14 +02:00
Steven Palma 9519ff5e09 Merge branch 'main' into feat/decouple_record_script
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2026-04-19 22:48:08 +02:00
Steven Palma 32a27cae8a filesize default change + more logs + filesize mb based episode + go back to init pos + rerun log + date end of repo_id 2026-04-19 16:50:19 +02:00
Steven Palma 8cee56e2d6 fix pre-commit 2026-04-17 16:46:58 +02:00
Steven Palma a76874f35e test dagger 2026-04-17 16:46:38 +02:00
Steven Palma 35bb2c7459 simplify dagger 2026-04-17 15:55:03 +02:00
Steven Palma 051f6c6803 Merge branch 'main' into feat/decouple_record_script 2026-04-17 14:25:18 +02:00
Steven Palma 04ae0312a2 HW tests fixes 2026-04-16 17:29:22 +02:00
Steven Palma cc634de9e7 add docstrings 2026-04-16 16:40:33 +02:00
Steven Palma 3eda5712d3 some more iterations 2026-04-16 15:52:23 +02:00
Steven Palma 783ec6e232 minor improvements 2026-04-16 14:34:22 +02:00
Steven Palma 4e3175ff15 address review 2026-04-15 19:31:53 +02:00
Steven Palma edd7fc52a8 feat: introduce inference engine strategy 2026-04-15 17:51:44 +02:00
Steven Palma 0f0f8b8961 imports and comments 2026-04-15 16:28:56 +02:00
Steven Palma 79db54dc34 Merge branch 'main' into feat/decouple_record_script 2026-04-15 11:06:45 +02:00
Steven Palma 6ae07878f7 Merge branch 'main' into feat/decouple_record_script 2026-04-14 22:54:29 +02:00
Steven Palma 10d05e03bc Merge branch 'main' into feat/decouple_record_script 2026-04-14 21:35:26 +02:00
Steven Palma f2c29d78cf more improvements and fixes 2026-04-14 17:51:03 +02:00
Steven Palma 8bc47e4318 target review 2026-04-14 17:14:09 +02:00
Steven Palma 49f32b9796 some more iterations 2026-04-14 16:34:52 +02:00
Steven Palma f55782f9f7 pre-commit run 2026-04-14 15:42:19 +02:00
Steven Palma 05a2604d6e first iteration 2026-04-14 15:42:04 +02:00
54 changed files with 5204 additions and 2816 deletions
+2
View File
@@ -61,6 +61,8 @@
title: SARM
title: "Reward Models"
- sections:
- local: inference
title: Policy Deployment (lerobot-rollout)
- local: async
title: Use Async Inference
- local: rtc
+17 -19
View File
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
### Teleoperator Requirements
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators in the current `examples/hil` scripts:**
**Compatible teleoperators:**
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
## Script
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
| Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
| Mode | Flag | Models |
| ------------------------ | ---------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
---
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/hil/hil_data_collection.py \
lerobot-rollout --strategy.type=dagger \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
@@ -111,8 +111,7 @@ python examples/hil/hil_data_collection.py \
--dataset.repo_id=your-username/hil-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--strategy.num_episodes=50 \
--interpolation_multiplier=2
```
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
For models with high inference latency, enable RTC for smooth execution:
```bash
python examples/hil/hil_data_collection.py \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
lerobot-rollout --strategy.type=dagger \
--inference.type=rtc \
--inference.rtc.execution_horizon=20 \
--inference.rtc.max_guidance_weight=5.0 \
--inference.rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
@@ -139,8 +138,7 @@ python examples/hil/hil_data_collection.py \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--strategy.num_episodes=50 \
--interpolation_multiplier=3
```
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
+26 -105
View File
@@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
<hfoptions id="eval">
<hfoption id="Command">
<hfoption id="Base mode (no recording)">
```bash
lerobot-record \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--robot.id=my_awesome_follower_arm \
--display_data=false \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_awesome_leader_arm \
--policy.path=${HF_USER}/my_policy
--task="Put lego brick into the transparent box" \
--duration=60
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.policies.act import ACTPolicy
from lerobot.policies import make_pre_post_processors
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
)
# Initialize the robot
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Run the policy inference loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
# Clean up
robot.disconnect()
dataset.push_to_hub()
<hfoption id="Sentry mode (with recording)">
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--duration=600
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
The `--strategy.type` flag selects the execution mode:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+261
View File
@@ -0,0 +1,261 @@
# Policy Deployment (lerobot-rollout)
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
## Quick Start
No extra dependencies are needed beyond your robot and policy extras.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=lerobot/act_koch_real \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--task="pick up cube" \
--duration=30
```
This runs the policy for 30 seconds with no recording.
---
## Strategies
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
### Base (`--strategy.type=base`)
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the box" \
--duration=60
```
| Flag | Description |
| ---------------- | ------------------------------------------------------ |
| `--duration` | Run time in seconds (0 = infinite) |
| `--task` | Task description passed to the policy |
| `--display_data` | Stream observations/actions to Rerun for visualization |
### Sentry (`--strategy.type=sentry`)
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/eval_data \
--dataset.single_task="Put lego brick into the box" \
--duration=3600
```
| Flag | Description |
| -------------------------------------- | ----------------------------------------------------------- |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
### Highlight (`--strategy.type=highlight`)
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
```bash
lerobot-rollout \
--strategy.type=highlight \
--strategy.ring_buffer_seconds=30 \
--strategy.save_key=s \
--strategy.push_key=h \
--policy.path=${HF_USER}/my_policy \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--dataset.repo_id=${HF_USER}/highlight_data \
--dataset.single_task="Pick up the red cube"
```
**Keyboard controls:**
| Key | Action |
| ------------------ | -------------------------------------------------------- |
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
| `h` (configurable) | Push dataset to Hub |
| `ESC` | Stop the session |
| Flag | Description |
| -------------------------------------- | ---------------------------------------------- |
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
### DAgger (`--strategy.type=dagger`)
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/hil_data \
--dataset.single_task="Fold the T-shirt"
```
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.record_autonomous=true \
--strategy.num_episodes=50 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/dagger_data \
--dataset.single_task="Grasp the block"
```
**Keyboard controls** (default input device):
| Key | Action |
| ------- | ------------------------------------------- |
| `Space` | Pause / resume policy execution |
| `Tab` | Start / stop human correction |
| `Enter` | Push dataset to Hub (corrections-only mode) |
| `ESC` | Stop the session |
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
| Flag | Description |
| ------------------------------------ | ------------------------------------------------------- |
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
---
## Inference Backends
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
### Sync (default)
One policy call per control tick. The main loop blocks until the action is computed.
Works with all policies. No extra flags needed.
### Real-Time Chunking (`--inference.type=rtc`)
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
```bash
lerobot-rollout \
--strategy.type=base \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--policy.path=${HF_USER}/pi0_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Pick up the cube" \
--duration=60 \
--device=cuda
```
| Flag | Description |
| ------------------------------------------- | -------------------------------------------------------------- |
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
---
## Common Flags
| Flag | Description | Default |
| --------------------------------- | ----------------------------------------------------------------- | ------- |
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
| `--robot.port` | Serial port for the robot | -- |
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
| `--fps` | Control loop frequency | 30 |
| `--duration` | Run time in seconds (0 = infinite) | 0 |
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
| `--task` | Task description (used when no dataset is provided) | -- |
| `--display_data` | Stream telemetry to Rerun visualization | false |
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
| `--interpolation_multiplier` | Action interpolation factor | 1 |
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
| `--resume` | Resume a previous recording session | false |
| `--play_sounds` | Vocal synthesis for events | true |
---
## Programmatic Usage
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
```python
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
cfg = RolloutConfig(
robot=my_robot_config,
policy=my_policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=30,
duration=60,
task="my task",
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=my_custom_action_processor, # optional
robot_observation_processor=my_custom_obs_processor, # optional
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
```
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
+7 -3
View File
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
## Testing RTC with a Real Robot
```bash
python examples/rtc/eval_with_real_robot.py \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USERNAME}/policy_repo_id \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
## References
+1 -1
View File
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
--task="task_description" \
--duration=1000 \
--fps=30 \
--rtc.enabled=true
--inference.type=rtc
```
---
File diff suppressed because it is too large Load Diff
-226
View File
@@ -1,226 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common.control_utils import is_headless
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
logger.info("[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("[HIL] Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("[HIL] End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("[HIL] Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
logger.info(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
logger.info("[HIL] RESET")
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
logger.info("Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
logger.info("Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
logger.info(
"%s\n Controls:\n"
" SPACE - Pause policy\n"
" c - Take control\n"
" p - Resume policy after pause/correction\n"
" → - End episode\n"
" ESC - Stop and push to hub",
mode,
)
+62 -31
View File
@@ -14,17 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.common.control_utils import init_keyboard_listener
import logging
import time
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 2
FPS = 30
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
@@ -83,43 +90,67 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot
robot_action_to_send = robot_action_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
log_say("Waiting for environment reset, press right arrow key when ready...")
if events["rerecord_episode"]:
log_say("Re-record episode")
+10 -9
View File
@@ -45,9 +45,6 @@ def main():
leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
@@ -77,6 +74,10 @@ def main():
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
teleop_action_processor, robot_action_processor, robot_observation_processor = (
make_default_processors()
)
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
@@ -87,14 +88,14 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
@@ -106,13 +107,13 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
+77
View File
@@ -0,0 +1,77 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained policy on LeKiwi without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). For a CLI entry point with the same capabilities plus
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
"""
from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
# by ``build_rollout_context`` to reload the full model.
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
# Assemble the rollout config: base strategy (no recording) + sync inference.
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
signal_handler = ProcessSignalHandler(use_threads=True)
# Build the context (connects robot, loads policy, wires the inference strategy).
# No custom processors here — LeKiwi runs on raw joint features.
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
+63 -32
View File
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
RobotProcessorPipeline,
make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 5
FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_joints_to_ee_pose_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
log_say("Waiting for environment reset, press right arrow key when ready...")
if events["rerecord_episode"]:
log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+13 -13
View File
@@ -65,14 +65,15 @@ def main():
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to EE action
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
tuple[RobotAction, RobotObservation], RobotAction
](
@@ -94,7 +95,7 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action
# Build pipeline to convert EE action to joints action (IK).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
@@ -107,7 +108,7 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to EE observation
# Build pipeline to convert joint observation to EE observation (FK).
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
@@ -118,13 +119,12 @@ def main():
to_output=transition_to_observation,
)
# Create the dataset
# Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features),
@@ -163,14 +163,14 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
@@ -182,13 +182,13 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
+126
View File
@@ -0,0 +1,126 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
with phone teleoperation in EE space, so at deployment we only need the
joint↔EE conversion on the robot side; the phone is not used.
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
(inline policy call). For recording during rollout, switch to Sentry,
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Peek at motor names once to build the kinematic solver.
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
-673
View File
@@ -1,673 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--device=cuda
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
make_default_robot_action_processor,
make_default_robot_observation_processor,
to_relative_actions,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: Tensor,
current_state: Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> Tensor:
"""Convert absolute leftovers into model-space for relative-action RTC policies.
When a policy uses relative actions, the RTC prefix (leftover actions from
the previous chunk) is stored in absolute space. Before feeding it back to
the policy we need to re-express it relative to the *current* robot state
and then re-normalize.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
# Only keep .pos joints + camera streams if the policy was trained on positions,
# not the full pos/vel/torque state the robot exposes.
observation_features_hw = {
key: value
for key, value in robot.observation_features().items()
if key.endswith(".pos") or isinstance(value, tuple)
}
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if relative_step is not None:
if relative_step.action_names is None:
cfg_names = getattr(cfg.policy, "action_feature_names", None)
if cfg_names:
relative_step.action_names = list(cfg_names)
else:
relative_step.action_names = [
k for k in robot.robot.action_features if k.endswith(".pos")
]
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Re-anchor leftover actions for relative-action policies.
# We need the *postprocessed* (absolute) leftover, not the original
# (normalized/relative) one that get_left_over() returns.
if (
prev_actions is not None
and relative_step is not None
and OBS_STATE in obs_with_policy_features
):
with action_queue.lock:
if action_queue.queue is not None:
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
else:
prev_actions_abs = None
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_actions_abs,
current_state=obs_with_policy_features[OBS_STATE],
relative_step=relative_step,
normalizer_step=normalizer_step,
policy_device=policy_device,
)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
if config.use_peft:
from peft import PeftConfig, PeftModel
peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
+63 -32
View File
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
RobotProcessorPipeline,
make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 5
FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_joints_to_ee_pose_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
log_say("Waiting for environment reset, press right arrow key when ready...")
if events["rerecord_episode"]:
log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+15 -17
View File
@@ -62,21 +62,20 @@ def main():
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# Build pipeline to convert follower joints to EE observation
# Build pipeline to convert follower joints to EE observation.
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
@@ -87,7 +86,7 @@ def main():
to_output=transition_to_observation,
)
# Build pipeline to convert leader joints to EE action
# Build pipeline to convert leader joints to EE action.
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
ForwardKinematicsJointsToEE(
@@ -98,9 +97,9 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to follower joints
# Build pipeline to convert EE action to follower joints (with safety bounds).
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[
steps=[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
@@ -115,13 +114,12 @@ def main():
to_output=transition_to_robot_action,
)
# Create the dataset
# Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=leader_joints_to_ee,
initial_features=create_initial_features(action=leader.action_features),
@@ -144,7 +142,7 @@ def main():
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
init_rerun(session_name="recording_so100_ee")
try:
if not leader.is_connected or not follower.is_connected:
@@ -160,14 +158,14 @@ def main():
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
@@ -179,13 +177,13 @@ def main():
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
+134
View File
@@ -0,0 +1,134 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). The custom observation/action processors convert between
joint space (robot hardware) and end-effector space (policy I/O) via
forward/inverse kinematics.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Kinematic solver: we need the motor-name list, so peek at the robot once.
# (The rollout engine owns the connected instance; we only use this for introspection.)
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
# Joint-space observation → EE-space observation (consumed by the policy).
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# EE-space action (produced by the policy) → joint-space action (sent to robot).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Policy config (full model is loaded inside build_rollout_context).
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
# otherwise skip the joint↔EE conversion and the policy would receive the
# wrong observation/action space.
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
+1
View File
@@ -275,6 +275,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
+2
View File
@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
"""
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .types import (
@@ -39,6 +40,7 @@ __all__ = [
"PolicyFeature",
"RTCAttentionSchedule",
# Config classes
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"PeftConfig",
+77
View File
@@ -0,0 +1,77 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str = ""
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str = ""
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
# Maximum number of frames to buffer per camera when using streaming encoding.
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
encoder_queue_maxsize: int = 30
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"
+2 -2
View File
@@ -71,8 +71,8 @@ class ForwardCompatibilityError(CompatibilityError):
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
DEFAULT_DATA_FILE_SIZE_IN_MB = 50 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 100 # Max size per file
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
+2 -1
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
@@ -21,7 +23,6 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
+3 -115
View File
@@ -1,116 +1,4 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
from lerobot.utils.action_interpolator import ActionInterpolator
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
__all__ = ["ActionInterpolator"]
+10 -10
View File
@@ -92,10 +92,10 @@ class ActionQueue:
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
with self.lock:
if self.queue is None:
return 0
return len(self.queue) - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
@@ -103,11 +103,10 @@ class ActionQueue:
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
with self.lock:
if self.queue is None:
return True
return len(self.queue) - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
@@ -115,7 +114,8 @@ class ActionQueue:
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
with self.lock:
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
+1 -1
View File
@@ -76,6 +76,7 @@ from lerobot.transport.utils import (
)
from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
@@ -94,7 +95,6 @@ from .gym_manipulator import (
make_robot_env,
step_env_and_process_transition,
)
from .process import ProcessSignalHandler
from .queue import get_last_item_from_queue
# Main entry point
+1 -1
View File
@@ -90,6 +90,7 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
@@ -99,7 +100,6 @@ from lerobot.utils.utils import (
from .buffer import ReplayBuffer, concatenate_batch_transitions
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .process import ProcessSignalHandler
@parser.wrap()
+82
View File
@@ -0,0 +1,82 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Policy deployment engine with pluggable rollout strategies."""
from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
from .configs import (
BaseStrategyConfig,
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
DatasetRecordConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
SentryStrategyConfig,
)
from .context import (
DatasetContext,
HardwareContext,
PolicyContext,
ProcessorContext,
RolloutContext,
RuntimeContext,
build_rollout_context,
)
from .inference import (
InferenceEngine,
InferenceEngineConfig,
RTCInferenceConfig,
RTCInferenceEngine,
SyncInferenceConfig,
SyncInferenceEngine,
create_inference_engine,
)
from .ring_buffer import RolloutRingBuffer
from .robot_wrapper import ThreadSafeRobot
from .strategies import RolloutStrategy, create_strategy
__all__ = [
"BaseStrategyConfig",
"DAggerKeyboardConfig",
"DAggerPedalConfig",
"DAggerStrategyConfig",
"DatasetContext",
"DatasetRecordConfig",
"HardwareContext",
"HighlightStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
"ProcessorContext",
"RTCInferenceConfig",
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"RolloutRingBuffer",
"RolloutStrategy",
"RolloutStrategyConfig",
"RuntimeContext",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceEngine",
"ThreadSafeRobot",
"build_rollout_context",
"create_inference_engine",
"create_strategy",
]
+270
View File
@@ -0,0 +1,270 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration dataclasses for the rollout deployment engine."""
from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
import draccus
from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.robots.config import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from .inference import InferenceEngineConfig, SyncInferenceConfig
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Strategy configs (polymorphic dispatch via draccus ChoiceRegistry)
# ---------------------------------------------------------------------------
@dataclass
class RolloutStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for rollout strategy configurations.
Use ``--strategy.type=<name>`` on the CLI to select a strategy.
"""
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@RolloutStrategyConfig.register_subclass("base")
@dataclass
class BaseStrategyConfig(RolloutStrategyConfig):
"""Autonomous rollout with no data recording."""
pass
@RolloutStrategyConfig.register_subclass("sentry")
@dataclass
class SentryStrategyConfig(RolloutStrategyConfig):
"""Continuous autonomous rollout with always-on recording.
Episode duration is derived from camera resolution, FPS, and
``target_video_file_size_mb`` so that each saved episode produces a
video file that has crossed the target size. This aligns episode
boundaries with the dataset's video file chunking, so each
``push_to_hub`` call uploads complete video files rather than
re-uploading a growing file that hasn't crossed the chunk boundary.
"""
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation. Episodes are
# saved once the estimated video duration would exceed this limit.
# Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
target_video_file_size_mb: float | None = None
@RolloutStrategyConfig.register_subclass("highlight")
@dataclass
class HighlightStrategyConfig(RolloutStrategyConfig):
"""Autonomous rollout with on-demand recording via ring buffer.
A memory-bounded ring buffer continuously captures telemetry. When
the user presses the save key, the buffer contents are flushed to
the dataset and live recording continues until the key is pressed
again.
"""
ring_buffer_seconds: float = 30.0
ring_buffer_max_memory_mb: float = 2048.0
save_key: str = "s"
push_key: str = "h"
@dataclass
class DAggerKeyboardConfig:
"""Keyboard key bindings for DAgger controls.
Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
special key names (``"space"``).
"""
pause_resume: str = "space"
correction: str = "tab"
upload: str = "enter"
@dataclass
class DAggerPedalConfig:
"""Foot pedal configuration for DAgger controls.
Pedal codes are evdev key code strings (e.g. ``"KEY_A"``).
"""
device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
pause_resume: str = "KEY_A"
correction: str = "KEY_B"
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
"""Human-in-the-loop data collection (DAgger / RaC).
Alternates between autonomous policy execution and human intervention.
Intervention frames are tagged with ``intervention=True``.
Input is controlled via either a keyboard or foot pedal, selected by
``input_device``. Each device exposes three actions:
1. **pause_resume** toggle policy execution on/off.
2. **correction** toggle human correction recording.
3. **upload** push dataset to hub on demand (corrections-only mode).
When ``record_autonomous=True`` (default) both autonomous and correction
frames are recorded with size-based episode rotation (same as Sentry)
and background uploading. ``push_to_hub`` is blocked while a correction
is in progress. Set to ``False`` to record only the human-correction
windows, where each correction becomes its own episode.
"""
num_episodes: int = 10
record_autonomous: bool = False
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation (record_autonomous
# mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
target_video_file_size_mb: float | None = None
input_device: str = "keyboard"
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
def __post_init__(self):
if self.input_device not in ("keyboard", "pedal"):
raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
# ---------------------------------------------------------------------------
# Top-level rollout config
# ---------------------------------------------------------------------------
@dataclass
class RolloutConfig:
"""Top-level configuration for the ``lerobot-rollout`` CLI.
Combines hardware, policy, strategy, and runtime settings. The
``__post_init__`` method performs fail-fast validation to reject
invalid flag combinations early.
"""
# Hardware
robot: RobotConfig | None = None
teleop: TeleoperatorConfig | None = None
# Policy (loaded from --policy.path via __post_init__)
policy: PreTrainedConfig | None = None
# Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger)
strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
# Inference backend (polymorphic: --inference.type=sync|rtc)
inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
# Dataset (required for sentry, highlight, dagger; None for base)
dataset: DatasetRecordConfig | None = None
# Runtime
fps: float = 30.0
duration: float = 0.0 # 0 = infinite (24/7 mode)
interpolation_multiplier: int = 1
device: str | None = None
task: str = ""
display_data: bool = False
# Display data on a remote Rerun server
display_ip: str | None = None
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
# Use vocal synthesis to read events
play_sounds: bool = True
resume: bool = False
# Torch compile
use_torch_compile: bool = False
torch_compile_backend: str = "inductor"
torch_compile_mode: str = "default"
compile_warmup_inferences: int = 2
def __post_init__(self):
"""Validate config invariants and load the policy config from ``--policy.path``."""
# --- Strategy-specific validation ---
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
raise ValueError("DAgger strategy requires --teleop.type to be set")
needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig))
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
raise ValueError(
"Base strategy does not record data. Use sentry, highlight, or dagger for recording."
)
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
if (
isinstance(self.strategy, SentryStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Sentry mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# Highlight writes frames while the policy is still running, so streaming is mandatory.
if (
isinstance(self.strategy, HighlightStrategyConfig)
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("Highlight mode forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# DAgger: streaming is mandatory only when the autonomous phase is also recorded.
if (
isinstance(self.strategy, DAggerStrategyConfig)
and self.strategy.record_autonomous
and self.dataset is not None
and not self.dataset.streaming_encoding
):
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
self.dataset.streaming_encoding = True
# --- Policy loading ---
if self.robot is None:
raise ValueError("--robot.type is required for rollout")
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.policy is None:
raise ValueError("--policy.path is required for rollout")
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
+429
View File
@@ -0,0 +1,429 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout context: shared state created once before strategy dispatch.
Grouped into five topical sub-contexts :class:`RuntimeContext`,
:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`,
and :class:`DatasetContext` assembled into :class:`RolloutContext`.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from threading import Event
import torch
from lerobot.configs import FeatureType, PreTrainedConfig
from lerobot.datasets import (
LeRobotDataset,
aggregate_pipeline_dataset_features,
create_initial_features,
)
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import (
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_processors,
rename_stats,
)
from lerobot.robots import make_robot_from_config
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
from .inference import (
InferenceEngine,
RTCInferenceConfig,
create_inference_engine,
)
from .robot_wrapper import ThreadSafeRobot
logger = logging.getLogger(__name__)
def _resolve_action_key_order(
policy_action_names: list[str] | None, dataset_action_names: list[str]
) -> list[str]:
"""Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
if not policy_action_names:
return dataset_action_names
policy_action_names = list(policy_action_names)
if len(policy_action_names) != len(dataset_action_names):
logger.warning(
"policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
len(policy_action_names),
len(dataset_action_names),
)
return dataset_action_names
if set(dataset_action_names) != set(policy_action_names):
logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
return dataset_action_names
return policy_action_names
# ---------------------------------------------------------------------------
# Sub-contexts
# ---------------------------------------------------------------------------
@dataclass
class RuntimeContext:
"""Runtime knobs shared with every strategy."""
cfg: RolloutConfig
shutdown_event: Event
@dataclass
class HardwareContext:
"""Connected hardware.
The raw robot is available via ``robot_wrapper.inner`` when needed
(e.g. for disconnect); strategies should otherwise go through the
thread-safe wrapper.
``initial_position`` stores the robot's joint positions at connect
time. Strategies use it to return the robot to a safe pose before
shutting down.
"""
robot_wrapper: ThreadSafeRobot
teleop: Teleoperator | None
initial_position: dict | None = None
@dataclass
class PolicyContext:
"""Loaded policy and its inference engine."""
policy: PreTrainedPolicy
preprocessor: PolicyProcessorPipeline
postprocessor: PolicyProcessorPipeline
inference: InferenceEngine
@dataclass
class ProcessorContext:
"""Robot-side pipelines (run outside the policy)."""
teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation]
@dataclass
class DatasetContext:
"""Dataset and feature bookkeeping."""
dataset: LeRobotDataset | None
dataset_features: dict = field(default_factory=dict)
hw_features: dict = field(default_factory=dict)
ordered_action_keys: list[str] = field(default_factory=list)
@dataclass
class RolloutContext:
"""Bundle of sub-contexts passed to every rollout strategy.
Built once by :func:`build_rollout_context` before strategy dispatch.
"""
runtime: RuntimeContext
hardware: HardwareContext
policy: PolicyContext
processors: ProcessorContext
data: DatasetContext
# ---------------------------------------------------------------------------
# Build
# ---------------------------------------------------------------------------
def build_rollout_context(
cfg: RolloutConfig,
shutdown_event: Event,
teleop_action_processor: RobotProcessorPipeline | None = None,
robot_action_processor: RobotProcessorPipeline | None = None,
robot_observation_processor: RobotProcessorPipeline | None = None,
) -> RolloutContext:
"""Wire up policy, processors, hardware, dataset, and inference engine.
The order is policy-first / hardware-last so a bad ``--policy.path``
fails fast without touching the robot.
"""
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
policy_config = cfg.policy
policy_class = get_policy_class(policy_config.type)
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
for attr in ("device", "use_amp"):
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
cli_val = getattr(cfg.policy, attr)
if cli_val is not None:
setattr(full_config, attr, cli_val)
if hasattr(full_config, "compile_model"):
full_config.compile_model = cfg.use_torch_compile
if full_config.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
if full_config.use_peft:
from peft import PeftConfig, PeftModel
peft_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
)
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
if is_rtc:
policy.config.rtc_config = cfg.inference.rtc
if hasattr(policy, "init_rtc_processor"):
policy.init_rtc_processor()
policy = policy.to(cfg.device)
policy.eval()
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
try:
if hasattr(torch, "compile"):
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
"options": {"triton.cudagraphs": False},
}
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
logger.info("torch.compile applied to predict_action_chunk")
except Exception as e:
logger.warning("Failed to apply torch.compile: %s", e)
# --- 2. Robot-side processors (user-supplied or defaults) --------
if (
teleop_action_processor is None
or robot_action_processor is None
or robot_observation_processor is None
):
_t, _r, _o = make_default_processors()
teleop_action_processor = teleop_action_processor or _t
robot_action_processor = robot_action_processor or _r
robot_observation_processor = robot_observation_processor or _o
# --- 3. Hardware (heaviest side-effect, deferred) -----------------
logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?")
robot = make_robot_from_config(cfg.robot)
robot.connect()
logger.info("Robot connected: %s", robot.name)
# Store the initial joint positions so we can return to a safe pose on shutdown.
initial_obs = robot.get_observation()
initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")}
logger.info("Captured initial robot position (%d keys)", len(initial_position))
robot_wrapper = ThreadSafeRobot(robot)
teleop = None
if cfg.teleop is not None:
logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?")
teleop = make_teleoperator_from_config(cfg.teleop)
teleop.connect()
logger.info("Teleoperator connected")
# DAgger requires teleop with motor control capabilities (enable_torque,
# disable_torque, write_goal_positions).
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
# if missing:
# teleop.disconnect()
# raise ValueError(
# f"DAgger strategy requires a teleoperator with motor control methods "
# f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
# )
# --- 4. Features + action-key reconciliation ---------------------
all_obs_features = robot.observation_features
observation_features_hw = {
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
}
action_features_hw = robot.action_features
# The action side is always needed: sync inference reads action names from
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
action_dataset_features = aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=action_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
# Observation-side aggregation is needed because of build_dataset_frame
observation_dataset_features = aggregate_pipeline_dataset_features(
pipeline=robot_observation_processor,
initial_features=create_initial_features(observation=observation_features_hw),
use_videos=cfg.dataset.video if cfg.dataset else True,
)
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(robot.action_features.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None)
ordered_action_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
raw_action_keys,
)
# Validate visual features if no rename_map is active
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
if not rename_map:
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {
f"observation.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
}
policy_subset = expected_visuals.issubset(provided_visuals)
hw_subset = provided_visuals.issubset(expected_visuals)
if not (policy_subset or hw_subset):
raise ValueError(
f"Visual feature mismatch between policy and robot hardware.\n"
f"Policy expects: {expected_visuals}\n"
f"Robot provides: {provided_visuals}"
)
# --- 5. Dataset -------------
dataset = None
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
)
else:
if isinstance(cfg.strategy, DAggerStrategyConfig):
dataset_features["intervention"] = {
"dtype": "bool",
"shape": (1,),
"names": None,
}
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
* len(robot.cameras if hasattr(robot, "cameras") else []),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
streaming_encoding=cfg.dataset.streaming_encoding,
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
encoder_threads=cfg.dataset.encoder_threads,
)
if dataset is not None:
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
dataset_stats = None
if dataset is not None:
dataset_stats = rename_stats(
dataset.meta.stats,
cfg.dataset.rename_map if cfg.dataset else {},
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy_config,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset_stats,
preprocessor_overrides={
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
},
)
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
logger.info(
"Creating inference engine (type=%s)...",
cfg.inference.type if hasattr(cfg.inference, "type") else "sync",
)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
inference_strategy = create_inference_engine(
cfg.inference,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
robot_wrapper=robot_wrapper,
hw_features=hw_features,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task=task_str,
fps=cfg.fps,
device=cfg.device,
use_torch_compile=cfg.use_torch_compile,
compile_warmup_inferences=cfg.compile_warmup_inferences,
shutdown_event=shutdown_event,
)
# --- 8. Assemble ---------------------------------------------------
logger.info("Rollout context assembled successfully")
return RolloutContext(
runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event),
hardware=HardwareContext(
robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position
),
policy=PolicyContext(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
inference=inference_strategy,
),
processors=ProcessorContext(
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
),
data=DatasetContext(
dataset=dataset,
dataset_features=dataset_features,
hw_features=hw_features,
ordered_action_keys=ordered_action_keys,
),
)
+39
View File
@@ -0,0 +1,39 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine package — backend-agnostic action production.
Concrete strategies (sync, RTC, ) expose the same small interface so
rollout strategies never branch on the inference backend.
"""
from .base import InferenceEngine
from .factory import (
InferenceEngineConfig,
RTCInferenceConfig,
SyncInferenceConfig,
create_inference_engine,
)
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
__all__ = [
"InferenceEngine",
"InferenceEngineConfig",
"RTCInferenceConfig",
"RTCInferenceEngine",
"SyncInferenceConfig",
"SyncInferenceEngine",
"create_inference_engine",
]
+88
View File
@@ -0,0 +1,88 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine ABC.
Rollout strategies consume actions through this small interface so they
do not need to know whether the inference engine is synchronous, runs in
a background thread (RTC), or comes from an external source.
"""
from __future__ import annotations
import abc
import torch
class InferenceEngine(abc.ABC):
"""Abstract backend for producing actions during rollout.
Subclasses decide whether inference happens inline, in a background
thread, or externally. The contract is minimal so new backends can
be added without touching rollout strategies.
Lifecycle
---------
``start`` prepare the backend (e.g. launch a background thread).
``stop`` shut the backend down cleanly.
``reset`` clear episode-scoped state (policy hidden state, queues).
Action production
-----------------
``get_action(obs_frame)`` return the next action tensor, or
``None`` if none is available (e.g. async queue empty). Sync
backends always compute from ``obs_frame``; async backends may
ignore it (they get observations via ``notify_observation``).
Optional hooks
--------------
``notify_observation`` / ``pause`` / ``resume`` have a no-op default
so rollout strategies can invoke them unconditionally.
"""
@abc.abstractmethod
def start(self) -> None:
"""Initialise the backend."""
@abc.abstractmethod
def stop(self) -> None:
"""Tear the backend down."""
@abc.abstractmethod
def reset(self) -> None:
"""Clear episode-scoped state."""
@abc.abstractmethod
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Return the next action tensor, or ``None`` if unavailable."""
def notify_observation(self, obs: dict) -> None: # noqa: B027
"""Publish the latest processed observation. Default: no-op."""
def pause(self) -> None: # noqa: B027
"""Pause background inference. Default: no-op."""
def resume(self) -> None: # noqa: B027
"""Resume background inference. Default: no-op."""
@property
def ready(self) -> bool:
"""True once the backend can produce actions (e.g. warmup done)."""
return True
@property
def failed(self) -> bool:
"""True if an unrecoverable error occurred in the backend."""
return False
+129
View File
@@ -0,0 +1,129 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference engine configs and factory.
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
backend requires registering its config subclass and dispatching it in
:func:`create_inference_engine`.
"""
from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
from threading import Event
import draccus
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configs
# ---------------------------------------------------------------------------
@dataclass
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for inference backend configuration.
Use ``--inference.type=<name>`` on the CLI to select a backend.
"""
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@InferenceEngineConfig.register_subclass("sync")
@dataclass
class SyncInferenceConfig(InferenceEngineConfig):
"""Inline synchronous inference (one policy call per control tick)."""
@InferenceEngineConfig.register_subclass("rtc")
@dataclass
class RTCInferenceConfig(InferenceEngineConfig):
"""Real-Time Chunking: async policy inference in a background thread."""
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
# constructing one here costs nothing and keeps draccus' CLI surface flat
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
rtc: RTCConfig = field(default_factory=RTCConfig)
queue_threshold: int = 30
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_inference_engine(
config: InferenceEngineConfig,
*,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
robot_wrapper: ThreadSafeRobot,
hw_features: dict,
dataset_features: dict,
ordered_action_keys: list[str],
task: str,
fps: float,
device: str | None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
shutdown_event: Event | None = None,
) -> InferenceEngine:
"""Instantiate the appropriate inference engine from a config object."""
logger.info("Creating inference engine: %s", config.type)
if isinstance(config, SyncInferenceConfig):
return SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task=task,
device=device,
robot_type=robot_wrapper.robot_type,
)
if isinstance(config, RTCInferenceConfig):
return RTCInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
robot_wrapper=robot_wrapper,
rtc_config=config.rtc,
hw_features=hw_features,
task=task,
fps=fps,
device=device,
use_torch_compile=use_torch_compile,
compile_warmup_inferences=compile_warmup_inferences,
rtc_queue_threshold=config.queue_threshold,
shutdown_event=shutdown_event,
)
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
+391
View File
@@ -0,0 +1,391 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking inference engine.
A background thread produces action chunks asynchronously via
:meth:`policy.predict_action_chunk`. The main control loop polls
``get_action`` for the next ready action; observations flow the other
way via ``notify_observation``.
"""
from __future__ import annotations
import logging
import math
import time
import traceback
from threading import Event, Lock, Thread
from typing import Any
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionQueue, LatencyTracker
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import (
NormalizerProcessorStep,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
to_relative_actions,
)
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
logger = logging.getLogger(__name__)
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
_RTC_IDLE_SLEEP_S: float = 0.01
# Backoff between transient inference errors (per consecutive failure).
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
# Consecutive transient errors tolerated before giving up and propagating shutdown.
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
# Hard timeout for joining the RTC thread on stop().
_RTC_JOIN_TIMEOUT_S: float = 3.0
# ---------------------------------------------------------------------------
# RTC helpers
# ---------------------------------------------------------------------------
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: torch.Tensor,
current_state: torch.Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> torch.Tensor:
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
is stored in absolute coordinates. Before feeding it back to the policy, this
helper re-expresses those actions relative to the robot's current joint state
and optionally normalizes them so the policy receives correctly scaled inputs.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
if prev_actions.ndim != 2:
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
steps, action_dim = prev_actions.shape
if steps == target_steps:
return prev_actions
if steps > target_steps:
return prev_actions[:target_steps]
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
padded[:steps] = prev_actions
return padded
# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------
class RTCInferenceEngine(InferenceEngine):
"""Async RTC inference: a background thread produces action chunks.
``get_action`` pops the next action from the shared queue (or
returns ``None`` if the queue is empty). The main loop should call
``notify_observation`` every tick and ``pause``/``resume`` around
human-intervention phases.
"""
def __init__(
self,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
robot_wrapper: ThreadSafeRobot,
rtc_config: RTCConfig,
hw_features: dict,
task: str,
fps: float,
device: str | None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
rtc_queue_threshold: int = 30,
shutdown_event: Event | None = None,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
self._postprocessor = postprocessor
self._robot = robot_wrapper
self._rtc_config = rtc_config
self._hw_features = hw_features
self._task = task
self._fps = fps
self._device = device or "cpu"
self._use_torch_compile = use_torch_compile
self._compile_warmup_inferences = compile_warmup_inferences
self._rtc_queue_threshold = rtc_queue_threshold
self._action_queue: ActionQueue | None = None
self._obs_holder: dict[str, Any] = {}
self._obs_lock = Lock()
self._policy_active = Event()
self._compile_warmup_done = Event()
self._shutdown_event = Event()
self._rtc_error = Event()
self._global_shutdown_event = shutdown_event
self._rtc_thread: Thread | None = None
if not self._use_torch_compile:
self._compile_warmup_done.set()
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
else:
logger.info(
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
compile_warmup_inferences,
)
# Processor introspection for relative-action re-anchoring.
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
self._normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if self._relative_step is not None:
if self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
if cfg_names:
self._relative_step.action_names = list(cfg_names)
else:
self._relative_step.action_names = [
k for k in robot_wrapper.action_features if k.endswith(".pos")
]
logger.info("Relative actions enabled: RTC prefix will be re-anchored")
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
@property
def ready(self) -> bool:
"""True once torch.compile warmup is complete (or immediately if compile is disabled)."""
return self._compile_warmup_done.is_set()
@property
def failed(self) -> bool:
"""True if the RTC background thread exited due to an unrecoverable error."""
return self._rtc_error.is_set()
@property
def action_queue(self) -> ActionQueue | None:
"""The shared action queue between the RTC thread and the main loop."""
return self._action_queue
def start(self) -> None:
"""Launch the RTC background thread."""
self._action_queue = ActionQueue(self._rtc_config)
self._obs_holder = {
"obs": None,
"robot_type": self._robot.robot_type,
}
self._shutdown_event.clear()
self._rtc_thread = Thread(
target=self._rtc_loop,
daemon=True,
name="RTCInference",
)
self._rtc_thread.start()
logger.info("RTC inference thread started")
def stop(self) -> None:
"""Signal the RTC thread to stop and wait for it."""
logger.info("Stopping RTC inference thread...")
self._shutdown_event.set()
self._policy_active.clear()
if self._rtc_thread is not None and self._rtc_thread.is_alive():
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
if self._rtc_thread.is_alive():
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
else:
logger.info("RTC inference thread stopped")
self._rtc_thread = None
def pause(self) -> None:
"""Pause the RTC background thread."""
logger.info("Pausing RTC inference thread")
self._policy_active.clear()
def resume(self) -> None:
"""Resume the RTC background thread."""
logger.info("Resuming RTC inference thread")
self._policy_active.set()
def reset(self) -> None:
"""Reset the policy, processors, and action queue."""
logger.info("Resetting RTC inference state (policy + processors + queue)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
if self._action_queue is not None:
self._action_queue.clear()
# ------------------------------------------------------------------
# Action production (called from main thread)
# ------------------------------------------------------------------
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Pop the next action from the RTC queue (ignores ``obs_frame``)."""
if self._action_queue is None:
return None
return self._action_queue.get()
def notify_observation(self, obs: dict) -> None:
"""Publish the latest observation for the RTC thread to consume."""
with self._obs_lock:
self._obs_holder["obs"] = obs
# ------------------------------------------------------------------
# RTC: background inference thread
# ------------------------------------------------------------------
def _rtc_loop(self) -> None:
"""Background thread that generates action chunks via RTC."""
try:
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / self._fps
policy_device = torch.device(self._device)
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
inference_count = 0
consecutive_errors = 0
while not self._shutdown_event.is_set():
if not self._policy_active.is_set():
time.sleep(_RTC_IDLE_SLEEP_S)
continue
queue = self._action_queue
with self._obs_lock:
obs = self._obs_holder.get("obs")
if queue is None or obs is None:
time.sleep(_RTC_IDLE_SLEEP_S)
continue
if queue.qsize() <= self._rtc_queue_threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
prev_actions = queue.get_left_over()
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
obs_batch = prepare_observation_for_inference(
obs_batch, policy_device, self._task, self._robot.robot_type
)
obs_batch["task"] = [self._task]
preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE)
if state_tensor is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_abs,
current_state=state_tensor,
relative_step=self._relative_step,
normalizer_step=self._normalizer_step,
policy_device=policy_device,
)
if prev_actions is not None:
prev_actions = _normalize_prev_actions_length(
prev_actions, target_steps=self._rtc_config.execution_horizon
)
actions = self._policy.predict_action_chunk(
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
)
original = actions.squeeze(0).clone()
processed = self._postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
inference_count += 1
consecutive_errors = 0
is_warmup = self._use_torch_compile and inference_count <= warmup_required
if is_warmup:
latency_tracker.reset()
else:
latency_tracker.add(new_latency)
queue.merge(original, processed, new_delay, idx_before)
if (
is_warmup
and inference_count >= warmup_required
and not self._compile_warmup_done.is_set()
):
self._compile_warmup_done.set()
logger.info("Compile warmup complete (%d inferences)", inference_count)
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
except Exception as e:
consecutive_errors += 1
logger.error(
"RTC inference error (%d/%d): %s",
consecutive_errors,
_RTC_MAX_CONSECUTIVE_ERRORS,
e,
)
logger.debug(traceback.format_exc())
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
# Persistent failure: stop retrying and propagate shutdown.
raise
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
else:
time.sleep(_RTC_IDLE_SLEEP_S)
except Exception as e:
logger.error("Fatal error in RTC thread: %s", e)
logger.error(traceback.format_exc())
self._rtc_error.set()
# Unblock any warmup waiters so the main loop doesn't spin forever
self._compile_warmup_done.set()
# Signal the top-level shutdown so strategies exit their control loops
if self._global_shutdown_event is not None:
self._global_shutdown_event.set()
+107
View File
@@ -0,0 +1,107 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synchronous inference engine: inline policy call per control tick."""
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from .base import InferenceEngine
logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor +
``select_action``) on the given observation frame and returns a
CPU action tensor reordered to match the dataset action keys.
"""
def __init__(
self,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
dataset_features: dict,
ordered_action_keys: list[str],
task: str,
device: str | None,
robot_type: str,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
self._postprocessor = postprocessor
self._dataset_features = dataset_features
self._ordered_action_keys = ordered_action_keys
self._task = task
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device,
len(ordered_action_keys),
)
def start(self) -> None:
"""No background resources to start."""
logger.info("SyncInferenceEngine started (inline mode — no background thread)")
def stop(self) -> None:
"""No background resources to stop."""
logger.info("SyncInferenceEngine stopped")
def reset(self) -> None:
"""Reset the policy and pre/post-processors."""
logger.info("Resetting sync inference state (policy + processors)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if obs_frame is None:
return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, self._device, self._task, self._robot_type
)
observation = self._preprocessor(observation)
action = self._policy.select_action(observation)
action = self._postprocessor(action)
action_tensor = action.squeeze(0).cpu()
# Reorder to match dataset action ordering so the caller can treat
# the returned tensor uniformly across backends.
action_dict = make_robot_action(action_tensor, self._dataset_features)
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
+112
View File
@@ -0,0 +1,112 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Memory-bounded ring buffer for the Highlight Reel rollout strategy."""
from __future__ import annotations
from collections import deque
import numpy as np
import torch
class RolloutRingBuffer:
"""Fixed-capacity circular buffer for observation/action frames.
Stores the last *N* seconds of telemetry in memory, bounded by both
time (``max_frames``) and memory (``max_memory_bytes``). When either
limit is reached the oldest frames are evicted.
.. note::
This class is **single-threaded**. ``append``/``drain``/``clear``
must all be called from the same thread (the rollout main loop).
Concurrent access from a background thread will corrupt
``_current_bytes`` accounting.
Parameters
----------
max_seconds:
Maximum duration of buffered telemetry.
max_memory_mb:
Hard memory cap in MiB. Frames are evicted when the estimated
total size exceeds this.
fps:
Frames per second used to convert ``max_seconds`` to a frame
count.
"""
def __init__(self, max_seconds: float = 30.0, max_memory_mb: float = 2048.0, fps: float = 30.0) -> None:
self._max_frames = int(max_seconds * fps)
self._max_bytes = int(max_memory_mb * 1024 * 1024)
self._buffer: deque[dict] = deque(maxlen=self._max_frames)
self._current_bytes: int = 0
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def append(self, frame: dict) -> None:
"""Add *frame* to the buffer, evicting the oldest if at capacity."""
frame_bytes = _estimate_frame_bytes(frame)
# Evict oldest frames until we are under the memory cap
while self._current_bytes + frame_bytes > self._max_bytes and self._buffer:
evicted = self._buffer.popleft()
self._current_bytes -= _estimate_frame_bytes(evicted)
self._buffer.append(frame)
self._current_bytes += frame_bytes
def drain(self) -> list[dict]:
"""Return all buffered frames and clear the buffer."""
frames = list(self._buffer)
self._buffer.clear()
self._current_bytes = 0
return frames
def clear(self) -> None:
"""Discard all buffered frames."""
self._buffer.clear()
self._current_bytes = 0
def __len__(self) -> int:
return len(self._buffer)
@property
def estimated_bytes(self) -> int:
"""Estimated total byte size of all buffered frames."""
return self._current_bytes
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _estimate_frame_bytes(frame: dict) -> int:
"""Rough byte estimate for a single frame dictionary."""
total = 0
for v in frame.values():
if isinstance(v, torch.Tensor):
# ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the
# memory cap is honoured even when frames hold unconverted tensors.
total += v.nelement() * v.element_size()
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
total += v.nbytes
elif isinstance(v, (int, float)):
total += 8
elif isinstance(v, (str, bytes)):
total += len(v)
return max(total, 1) # avoid zero-size frames
+79
View File
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thread-safe robot wrapper for concurrent observation/action access."""
from __future__ import annotations
from threading import Lock
from typing import Any
from lerobot.robots import Robot
class ThreadSafeRobot:
"""Lock-protected wrapper around a :class:`Robot` for use with background threads.
When RTC inference runs in a background thread while the main loop
executes actions, both threads may access the robot concurrently.
This wrapper serialises ``get_observation`` and ``send_action`` calls.
Read-only properties are proxied without the lock since they don't
mutate hardware state.
"""
def __init__(self, robot: Robot) -> None:
self._robot = robot
self._lock = Lock()
# -- Lock-protected I/O --------------------------------------------------
def get_observation(self) -> dict[str, Any]:
with self._lock:
return self._robot.get_observation()
def send_action(self, action: dict[str, Any] | Any) -> Any:
with self._lock:
return self._robot.send_action(action)
# -- Read-only proxies (no lock needed) -----------------------------------
@property
def observation_features(self) -> dict:
return self._robot.observation_features
@property
def action_features(self) -> dict:
return self._robot.action_features
@property
def name(self) -> str:
return self._robot.name
@property
def robot_type(self) -> str:
return self._robot.robot_type
@property
def cameras(self):
return getattr(self._robot, "cameras", {})
@property
def is_connected(self) -> bool:
return self._robot.is_connected
@property
def inner(self) -> Robot:
"""Access the underlying robot (e.g. for connect/disconnect)."""
return self._robot
@@ -0,0 +1,36 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategies — public API re-exports."""
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
__all__ = [
"BaseStrategy",
"DAggerEvents",
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
"estimate_max_episode_seconds",
"safe_push_to_hub",
"send_next_action",
]
+79
View File
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base rollout strategy: autonomous policy execution with no data recording."""
from __future__ import annotations
import logging
import time
from lerobot.utils.robot_utils import precise_sleep
from ..context import RolloutContext
from .core import RolloutStrategy, send_next_action
logger = logging.getLogger(__name__)
class BaseStrategy(RolloutStrategy):
"""Autonomous policy rollout with no data recording.
All actions flow through the ``robot_action_processor`` pipeline
before reaching the robot.
"""
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine."""
self._init_engine(ctx)
logger.info("Base strategy ready")
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous control loop until shutdown or duration expires."""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(cfg.fps)
start_time = time.perf_counter()
engine.resume()
logger.info("Base strategy control loop started")
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
def teardown(self, ctx: RolloutContext) -> None:
"""Disconnect hardware and stop inference."""
self._teardown_hardware(ctx.hardware)
logger.info("Base strategy teardown complete")
+272
View File
@@ -0,0 +1,272 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategy ABC and shared action-dispatch helper."""
from __future__ import annotations
import abc
import logging
import time
from typing import TYPE_CHECKING
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_rerun_data
from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import HardwareContext, RolloutContext, RuntimeContext
logger = logging.getLogger(__name__)
class RolloutStrategy(abc.ABC):
"""Abstract base for rollout execution strategies.
Each concrete strategy implements a self-contained control loop with
its own recording/interaction semantics. Strategies are mutually
exclusive only one runs per session.
"""
def __init__(self, config: RolloutStrategyConfig) -> None:
self.config = config
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._warmup_flushed: bool = False
def _init_engine(self, ctx: RolloutContext) -> None:
"""Attach the inference engine and action interpolator, then start the backend.
Creates an :class:`ActionInterpolator` from the config's
``interpolation_multiplier`` and starts the inference engine.
Call this from ``setup()`` so strategies share identical
initialisation without duplicating code.
"""
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
self._engine = ctx.policy.inference
logger.info("Starting inference engine...")
self._engine.start()
self._warmup_flushed = False
logger.info("Inference engine started")
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
"""Handle torch.compile warmup phase.
Returns ``True`` if the caller should ``continue`` (still warming
up). On the first post-warmup iteration the engine and
interpolator are reset so stale warmup state is discarded.
"""
engine = self._engine
interpolator = self._interpolator
if not use_torch_compile:
return False
if not engine.ready:
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
return True
if not self._warmup_flushed:
logger.info("Warmup complete — flushing stale state and resuming engine")
engine.reset()
interpolator.reset()
self._warmup_flushed = True
engine.resume()
return False
def _teardown_hardware(self, hw: HardwareContext) -> None:
"""Stop the inference engine, return robot to initial position, and disconnect hardware."""
if self._engine is not None:
logger.info("Stopping inference engine...")
self._engine.stop()
robot = hw.robot_wrapper.inner
if robot.is_connected:
if hw.initial_position:
logger.info("Returning robot to initial position before shutdown...")
self._return_to_initial_position(hw)
logger.info("Disconnecting robot...")
robot.disconnect()
teleop = hw.teleop
if teleop is not None and teleop.is_connected:
logger.info("Disconnecting teleoperator...")
teleop.disconnect()
@staticmethod
def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None:
"""Smoothly interpolate the robot back to its initial position."""
robot = hw.robot_wrapper
target = hw.initial_position
try:
current_obs = robot.get_observation()
current_pos = {k: v for k, v in current_obs.items() if k in target}
steps = max(int(duration_s * fps), 1)
for step in range(1, steps + 1):
t = step / steps
interp = {}
for k in current_pos:
interp[k] = current_pos[k] * (1 - t) + target[k] * t
robot.send_action(interp)
precise_sleep(1 / fps)
except Exception as e:
logger.warning("Could not return to initial position: %s", e)
@staticmethod
def _log_telemetry(
obs_processed: dict | None,
action_dict: dict | None,
runtime_ctx: RuntimeContext,
) -> None:
"""Log observation/action telemetry to Rerun if display_data is enabled."""
cfg = runtime_ctx.cfg
if not cfg.display_data:
return
log_rerun_data(
observation=obs_processed,
action=action_dict,
compress_images=cfg.display_compressed_images,
)
@abc.abstractmethod
def setup(self, ctx: RolloutContext) -> None:
"""Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""
@abc.abstractmethod
def run(self, ctx: RolloutContext) -> None:
"""Main rollout loop. Returns when shutdown is requested or duration expires."""
@abc.abstractmethod
def teardown(self, ctx: RolloutContext) -> None:
"""Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def safe_push_to_hub(dataset, tags=None, private=False) -> bool:
"""Push dataset to hub, skipping if no episodes have been saved.
Returns ``True`` if the push was attempted, ``False`` if skipped.
"""
if dataset.num_episodes == 0:
logger.warning("No episodes saved — skipping push to hub")
return False
dataset.push_to_hub(tags=tags, private=private)
return True
def estimate_max_episode_seconds(
dataset_features: dict,
fps: float,
target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
) -> float:
"""Conservatively estimate how many seconds of video will exceed *target_size_mb*.
Each camera produces its own video file, so the episode duration is
driven by the **slowest** camera to fill ``target_size_mb`` i.e.
the one with the fewest pixels per frame (lowest bitrate).
Uses a deliberately **low** bits-per-pixel estimate so the computed
duration is *longer* than reality. By the time the timer fires the
actual video file is guaranteed to have crossed the target size,
which aligns episode boundaries with the dataset's video-file
chunking each ``push_to_hub`` uploads complete files rather than
re-uploading a still-growing one.
The estimate ignores codec-specific settings (CRF, preset) on purpose:
we only need a rough lower bound on bitrate, not a precise prediction.
Falls back to 600 s (10 min) when no video features are present.
"""
# 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
# robot footage (real-world is typically 0.1 0.3 bpp). Under-
# estimating the bitrate over-estimates the time → the episode will be
# *larger* than target_size_mb when we save, which is what we want.
conservative_bpp = 0.1
# Collect per-camera pixel counts — each camera has its own video file.
camera_pixels = []
for feat in dataset_features.values():
if feat.get("dtype") == "video":
shape = feat.get("shape", ())
# Assuming shape could be (C, H, W) or (T, C, H, W)
# We want to extract the spatial dimensions.
if len(shape) >= 3:
h, w = shape[-2], shape[-1]
pixels = h * w
if pixels > 0:
camera_pixels.append(pixels)
if not camera_pixels:
return 600.0
# Use the smallest camera: it produces the lowest bitrate and therefore
# takes the longest to reach the target — the conservative choice.
min_pixels = min(camera_pixels)
bits_per_frame = min_pixels * conservative_bpp
bytes_per_second = (bits_per_frame * fps) / 8
# Guard against division by zero just in case
if bytes_per_second <= 0:
return 600.0
return (target_size_mb * 1024 * 1024) / bytes_per_second
# ---------------------------------------------------------------------------
# Shared action-dispatch helper
# ---------------------------------------------------------------------------
def send_next_action(
obs_processed: dict,
obs_raw: dict,
ctx: RolloutContext,
interpolator: ActionInterpolator,
) -> dict | None:
"""Dispatch the next action to the robot.
Pulls the next action tensor from the inference engine, feeds the
interpolator, and sends the interpolated action through the
``robot_action_processor`` to the robot. Works identically for
sync and async backends the rollout strategy never needs to branch.
Returns the action dict that was sent, or ``None`` if no action was
ready (e.g. empty async queue, interpolator not yet primed).
"""
engine = ctx.policy.inference
features = ctx.data.dataset_features
ordered_keys = ctx.data.ordered_action_keys
if interpolator.needs_new_action():
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_tensor = engine.get_action(obs_frame)
if action_tensor is not None:
interpolator.add(action_tensor.cpu())
interp = interpolator.get()
if interp is None:
return None
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
ctx.hardware.robot_wrapper.send_action(processed)
return action_dict
+733
View File
@@ -0,0 +1,733 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DAgger rollout strategy: Human-in-the-Loop data collection.
Implements the RaC paradigm (Recovery and Correction) for interactive
imitation learning. Alternates between autonomous policy execution and
human intervention via teleoperator.
Input is controlled via either a keyboard or foot pedal, selected by
the ``input_device`` config field. Each device exposes three actions:
1. **pause_resume** Toggle policy execution (AUTONOMOUS <-> PAUSED).
2. **correction** Toggle correction recording (PAUSED <-> CORRECTING).
3. **upload** Push dataset to hub on demand (corrections-only mode).
ESC (keyboard only) Stop session.
Recording Modes:
``record_autonomous=True``: Sentry-like continuous recording with
time-based episode rotation. Both autonomous and correction
frames are recorded; corrections tagged ``intervention=True``.
``record_autonomous=False``: Only correction windows are recorded.
Each correction (start to stop) becomes one episode.
"""
from __future__ import annotations
import contextlib
import enum
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from typing import Any
import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.pedal import start_pedal_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------
class DAggerPhase(enum.Enum):
"""Observable phases of a DAgger episode."""
AUTONOMOUS = "autonomous" # Policy driving
PAUSED = "paused" # Engine paused, teleop aligned, awaiting input
CORRECTING = "correcting" # Human driving via teleop, recording interventions
# Valid (current_phase, event) -> next_phase
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
(DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
(DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
(DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
(DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
}
class DAggerEvents:
"""Thread-safe container for DAgger input device events.
The keyboard/pedal threads write transition requests; the main loop
consumes them.
"""
def __init__(self) -> None:
self._lock = Lock()
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition: str | None = None
# Session-level flags
self.stop_recording = Event()
self.upload_requested = Event()
# -- Thread-safe phase access ------------------------------------------
@property
def phase(self) -> DAggerPhase:
"""Current phase of the DAgger state machine."""
with self._lock:
return self._phase
@phase.setter
def phase(self, value: DAggerPhase) -> None:
with self._lock:
self._phase = value
def request_transition(self, event: str) -> None:
"""Request a phase transition (called from keyboard/pedal threads).
Only enqueues the request if it corresponds to a valid transition
from the current phase, preventing impossible state changes.
"""
with self._lock:
if (self._phase, event) in _DAGGER_TRANSITIONS:
self._pending_transition = event
def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
"""Consume a pending transition (called from main loop)."""
with self._lock:
if self._pending_transition is None:
return None
key = (self._phase, self._pending_transition)
self._pending_transition = None
new_phase = _DAGGER_TRANSITIONS.get(key)
if new_phase is None:
return None
old_phase = self._phase
self._phase = new_phase
return old_phase, new_phase
def reset(self) -> None:
"""Reset all transient state for a fresh session."""
with self._lock:
self._phase = DAggerPhase.AUTONOMOUS
self._pending_transition = None
self.upload_requested.clear()
# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
"""Smoothly move teleop to target position via linear interpolation.
Requires the teleoperator to support motor control methods
(``enable_torque``, ``write_goal_positions``, ``get_action``).
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
"""Initialise keyboard listener with DAgger 3-key controls.
Returns the pynput Listener (or ``None`` in headless mode or when
pynput is unavailable).
"""
if not PYNPUT_AVAILABLE or is_headless():
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
return None
# Map config key names to pynput Key objects for special keys
special_keys = {
"space": keyboard.Key.space,
"tab": keyboard.Key.tab,
"enter": keyboard.Key.enter,
}
def _resolve_key(key) -> str | None:
"""Resolve a pynput key event to a config-comparable string."""
if key == keyboard.Key.esc:
return "esc"
for name, pynput_key in special_keys.items():
if key == pynput_key:
return name
if hasattr(key, "char") and key.char:
return key.char
return None
# Build mapping: resolved key string -> DAgger event name
key_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(key):
try:
resolved = _resolve_key(key)
if resolved is None:
return
if resolved == "esc":
logger.info("Stop recording...")
events.stop_recording.set()
return
if resolved in key_to_event:
events.request_transition(key_to_event[resolved])
if resolved == cfg.upload:
events.upload_requested.set()
except Exception as e:
logger.debug("Key error: %s", e)
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
cfg.upload,
)
return listener
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
"""Initialise foot pedal listener with DAgger 3-pedal controls.
Returns the pedal listener thread (or ``None`` if evdev is unavailable).
"""
code_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(code: str) -> None:
if code in code_to_event:
events.request_transition(code_to_event[code])
if code == cfg.upload:
events.upload_requested.set()
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
return start_pedal_listener(on_press, device_path=cfg.device_path)
# ---------------------------------------------------------------------------
# DAgger Strategy
# ---------------------------------------------------------------------------
class DAggerStrategy(RolloutStrategy):
"""Human-in-the-Loop data collection with intervention tagging.
State machine::
AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
--(key1)--> AUTONOMOUS
Recording modes:
``record_autonomous=True``: Sentry-like continuous recording with
time-based episode rotation. Intervention frames tagged True.
``record_autonomous=False``: Only correction windows recorded.
Each correction = one episode. Upload on demand via key3.
"""
config: DAggerStrategyConfig
def __init__(self, config: DAggerStrategyConfig):
super().__init__(config)
self._listener = None
self._pedal_thread = None
self._events = DAggerEvents()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
self._needs_push = Event()
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine and input device listener."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
self._episode_duration_s = estimate_max_episode_seconds(
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
)
if self.config.input_device == "keyboard":
self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
else:
self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)
record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
logger.info(
"DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)",
self.config.input_device,
self.config.num_episodes,
record_mode,
self._episode_duration_s,
)
def run(self, ctx: RolloutContext) -> None:
"""Run DAgger episodes with human-in-the-loop intervention."""
if self.config.record_autonomous:
self._run_continuous(ctx)
else:
self._run_corrections_only(ctx)
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping DAgger recording")
log_say("Stopping DAgger recording", play_sounds)
if self._listener is not None and not is_headless():
logger.info("Stopping keyboard listener")
self._listener.stop()
# Flush any queued/running push cleanly
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
# Continuous recording mode (record_autonomous=True)
# ------------------------------------------------------------------
def _run_continuous(self, ctx: RolloutContext) -> None:
"""Sentry-like continuous recording with intervention tagging.
Episodes are auto-rotated every ``episode_time_s`` seconds and
uploaded in the background every ``upload_every_n_episodes`` episodes.
Both autonomous and correction frames are recorded; corrections are
tagged with ``intervention=True``.
"""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
play_sounds = cfg.play_sounds
engine.reset()
interpolator.reset()
events.reset()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
record_tick = 0
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
episode_duration_s = self._episode_duration_s
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
with VideoEncodingManager(dataset):
try:
while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
# Process transitions
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
phase = events.phase
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
# --- CORRECTING: human teleop control ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
# --- AUTONOMOUS: policy control ---
else:
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
last_action = ctx.processors.robot_action_processor((action_dict, obs))
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([False], dtype=bool),
}
dataset.add_frame(frame)
record_tick += 1
# Episode rotation derived from video file-size target.
# Do NOT save mid-correction — wait for the correction
# to finish so the episode boundary is clean.
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
episodes_since_push = 0
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("DAgger continuous control loop ended — pausing engine")
engine.pause()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# Corrections-only mode (record_autonomous=False)
# ------------------------------------------------------------------
def _run_corrections_only(self, ctx: RolloutContext) -> None:
"""Record only human correction windows. Each correction = one episode.
The policy runs autonomously without recording. When the user
pauses and starts a correction, frames are recorded with
``intervention=True``. Stopping the correction saves the episode.
The dataset can be uploaded on demand via the upload key/pedal.
"""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
play_sounds = cfg.play_sounds
engine.reset()
interpolator.reset()
events.reset()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
record_tick = 0
recorded = 0
logger.info(
"DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes
)
with VideoEncodingManager(dataset):
try:
while (
recorded < self.config.num_episodes
and not events.stop_recording.is_set()
and not ctx.runtime.shutdown_event.is_set()
):
loop_start = time.perf_counter()
# Process transitions
transition = events.consume_transition()
if transition is not None:
old_phase, new_phase = transition
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
last_action = None
# Correction ended -> save episode (blocking if not streaming)
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
with self._episode_lock:
dataset.save_episode()
recorded += 1
self._needs_push.set()
logger.info(
"Correction %d/%d saved",
recorded,
self.config.num_episodes,
)
log_say(f"Correction {recorded} saved", play_sounds)
# On-demand upload
if events.upload_requested.is_set():
events.upload_requested.clear()
logger.info("Upload requested by user")
self._background_push(dataset, cfg)
phase = events.phase
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
# --- CORRECTING: human teleop control + recording ---
if phase == DAggerPhase.CORRECTING:
teleop_action = teleop.get_action()
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
dataset.add_frame(
{
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([True], dtype=bool),
}
)
record_tick += 1
# --- PAUSED: hold position ---
elif phase == DAggerPhase.PAUSED:
if last_action:
robot.send_action(last_action)
# --- AUTONOMOUS: policy control (no recording) ---
else:
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
last_action = ctx.processors.robot_action_processor((action_dict, obs))
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("DAgger corrections-only loop ended — pausing engine")
engine.pause()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# State-machine transition side-effects
# ------------------------------------------------------------------
@staticmethod
def _apply_transition(
old_phase: DAggerPhase,
new_phase: DAggerPhase,
engine,
interpolator,
robot: ThreadSafeRobot,
teleop: Teleoperator,
) -> None:
"""Execute side-effects for a validated phase transition."""
logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
logger.info("Pausing engine — robot holds position")
engine.pause()
obs = robot.get_observation()
_robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode — human teleop control")
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
logger.info("Resuming autonomous mode — resetting engine and interpolator")
interpolator.reset()
engine.reset()
engine.resume()
# ------------------------------------------------------------------
# Background push (shared by both modes)
# ------------------------------------------------------------------
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor.
The executor's max_workers=1 guarantees at most one push runs at
a time; submitted tasks are queued rather than dropped. Pushes
are blocked while the operator is mid-correction to avoid
uploading a partially-recorded episode.
"""
if self._push_executor is None:
return
if self._events.phase == DAggerPhase.CORRECTING:
logger.info("Skipping push — correction in progress")
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
with self._episode_lock:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
+45
View File
@@ -0,0 +1,45 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Strategy factory: config type-name → strategy class dispatch."""
from __future__ import annotations
from typing import TYPE_CHECKING
from .base import BaseStrategy
from .core import RolloutStrategy
from .dagger import DAggerStrategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
if TYPE_CHECKING:
from lerobot.rollout import RolloutStrategyConfig
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
"""Instantiate the appropriate strategy from a config object.
Dispatches on ``config.type`` (the name registered via
``draccus.ChoiceRegistry``).
"""
if config.type == "base":
return BaseStrategy(config)
if config.type == "sentry":
return SentryStrategy(config)
if config.type == "highlight":
return HighlightStrategy(config)
if config.type == "dagger":
return DAggerStrategy(config)
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
+277
View File
@@ -0,0 +1,277 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Highlight Reel strategy: on-demand recording via ring buffer."""
from __future__ import annotations
import contextlib
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
class HighlightStrategy(RolloutStrategy):
"""Autonomous rollout with on-demand recording via ring buffer.
The robot runs autonomously while a memory-bounded ring buffer
captures continuous telemetry. When the user presses the save key:
1. The ring buffer is flushed to the dataset (last *Z* seconds).
2. Live recording continues until the save key is pressed again.
3. The episode is saved and the ring buffer resumes capturing.
Requires ``streaming_encoding=True`` (enforced in config validation)
so that ``dataset.add_frame`` is a non-blocking queue put draining
900 frames stays sub-ms per frame.
"""
config: HighlightStrategyConfig
def __init__(self, config: HighlightStrategyConfig):
super().__init__(config)
require_package("pynput", extra="pynput-dep")
self._ring: RolloutRingBuffer | None = None
self._listener = None
self._save_requested = ThreadingEvent()
self._recording_live = ThreadingEvent()
self._push_requested = ThreadingEvent()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine, ring buffer, and keyboard listener."""
self._init_engine(ctx)
self._ring = RolloutRingBuffer(
max_seconds=self.config.ring_buffer_seconds,
max_memory_mb=self.config.ring_buffer_max_memory_mb,
fps=ctx.runtime.cfg.fps,
)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
logger.info(
"Ring buffer initialized (max_seconds=%.0f, max_memory=%.0fMB)",
self.config.ring_buffer_seconds,
self.config.ring_buffer_max_memory_mb,
)
self._setup_keyboard(ctx.runtime.shutdown_event)
logger.info(
"Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
self.config.ring_buffer_seconds,
self.config.save_key,
self.config.push_key,
)
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous loop, buffering frames and recording on demand."""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
dataset = ctx.data.dataset
ring = self._ring
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
engine.resume()
play_sounds = cfg.play_sounds
start_time = time.perf_counter()
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
logger.info("Highlight strategy recording started (press '%s' to save)", self.config.save_key)
with VideoEncodingManager(dataset):
try:
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# NOTE: ``is_set()`` then ``clear()`` is not atomic
# against the keyboard thread setting the flag again
# in between — but that is benign: we lose at most one
# toggle, processed on the next iteration. The
# ``_recording_live`` branch below is reached in the
# SAME iteration after ``clear()`` runs, so a frame
# finalised by ``save_episode()`` is never re-added to
# the next episode.
if self._save_requested.is_set():
self._save_requested.clear()
if not self._recording_live.is_set():
logger.info(
"Flushing ring buffer (%d frames) + starting live recording",
len(ring),
)
for buffered_frame in ring.drain():
dataset.add_frame(buffered_frame)
self._recording_live.set()
else:
dataset.add_frame(frame)
dataset.save_episode()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
log_say(
f"Episode {dataset.num_episodes} saved",
play_sounds,
)
self._recording_live.clear()
if self._push_requested.is_set():
self._push_requested.clear()
logger.info("Push requested by user")
self._background_push(dataset, cfg)
if self._recording_live.is_set():
dataset.add_frame(frame)
else:
ring.append(frame)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("Highlight control loop ended")
if self._recording_live.is_set():
logger.info("Saving in-progress live episode")
with contextlib.suppress(Exception):
dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping highlight recording")
log_say("Stopping highlight recording", play_sounds)
if self._listener is not None:
logger.info("Stopping keyboard listener")
self._listener.stop()
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("Highlight strategy teardown complete")
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
"""Set up keyboard listener for save and push keys."""
if is_headless():
logger.warning("Headless environment — highlight keys unavailable")
return
try:
save_key = self.config.save_key
push_key = self.config.push_key
def on_press(key):
with contextlib.suppress(Exception):
if hasattr(key, "char") and key.char == save_key:
self._save_requested.set()
elif hasattr(key, "char") and key.char == push_key:
self._push_requested.set()
elif key == keyboard.Key.esc:
self._save_requested.clear()
shutdown_event.set()
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor."""
if self._push_executor is None:
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
+225
View File
@@ -0,0 +1,225 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sentry rollout strategy: continuous autonomous recording with auto-upload."""
from __future__ import annotations
import contextlib
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
logger = logging.getLogger(__name__)
class SentryStrategy(RolloutStrategy):
"""Continuous autonomous rollout with always-on recording.
Episode duration is derived from camera resolution, FPS, and
``DEFAULT_VIDEO_FILE_SIZE_IN_MB`` so that each saved episode
produces a video file that has crossed the chunk-size boundary.
This keeps ``push_to_hub`` efficient it uploads complete video
files rather than re-uploading a still-growing one.
The dataset is pushed to the Hub via a bounded single-worker executor
so no push is ever silently dropped and exactly one push runs at a
time.
Policy state (hidden state, RTC queue) intentionally persists across
episode boundaries Sentry slices one continuous rollout, the robot
does not reset between slices.
Requires ``streaming_encoding=True`` (enforced in config validation)
to prevent disk I/O from blocking the control loop.
"""
config: SentryStrategyConfig
def __init__(self, config: SentryStrategyConfig):
super().__init__(config)
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
self._needs_push = Event()
self._episode_lock = Lock()
def setup(self, ctx: RolloutContext) -> None:
"""Initialise the inference engine and background push executor."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push")
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
self._episode_duration_s = estimate_max_episode_seconds(
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
)
logger.info(
"Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
self._episode_duration_s,
self.config.upload_every_n_episodes,
)
def run(self, ctx: RolloutContext) -> None:
"""Run the continuous recording loop with automatic episode rotation."""
engine = self._engine
cfg = ctx.runtime.cfg
robot = ctx.hardware.robot_wrapper
dataset = ctx.data.dataset
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
engine.resume()
play_sounds = cfg.play_sounds
episode_duration_s = self._episode_duration_s
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
logger.info("Sentry recording started (episode_duration=%.0fs)", episode_duration_s)
with VideoEncodingManager(dataset):
try:
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
obs_processed = ctx.processors.robot_observation_processor(obs)
engine.notify_observation(obs_processed)
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# ``add_frame`` writes to the in-progress episode buffer; the
# background pusher only ever touches *finalised* episode
# artifacts on disk. The two operate on disjoint state, so
# ``add_frame`` does not need ``_episode_lock``.
dataset.add_frame(frame)
# Episode rotation derived from video file-size target.
# The duration is a conservative estimate so the actual
# video has crossed DEFAULT_VIDEO_FILE_SIZE_IN_MB by now,
# keeping push_to_hub efficient (uploads complete files).
elapsed = time.perf_counter() - episode_start
if elapsed >= episode_duration_s:
# ``save_episode`` finalises the in-progress episode and
# flushes it to disk; ``_episode_lock`` serialises this with
# ``push_to_hub`` (run in the background executor) so the
# pusher never reads a half-written episode.
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
episodes_since_push = 0
episode_start = time.perf_counter()
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
finally:
logger.info("Sentry control loop ended — saving final episode")
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
def teardown(self, ctx: RolloutContext) -> None:
"""Flush pending pushes, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping sentry recording")
log_say("Stopping sentry recording", play_sounds)
# Flush any queued/running push cleanly.
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("Sentry strategy teardown complete")
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor.
The executor's max_workers=1 guarantees at most one push runs at
a time; submitted tasks are queued rather than dropped.
"""
if self._push_executor is None:
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
with self._episode_lock:
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
):
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
+83 -256
View File
@@ -13,70 +13,62 @@
# limitations under the License.
"""
Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy.
Records a dataset via teleoperation. This is a pure data-collection
tool no policy inference. For deploying trained policies, use
``lerobot-rollout`` instead.
Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras)
Example:
```shell
lerobot-record \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--robot.id=black \
--dataset.repo_id=<my_username>/<my_dataset_name> \
--dataset.num_episodes=2 \
--dataset.single_task="Grab the cube" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
lerobot-record \\
--robot.type=so100_follower \\
--robot.port=/dev/tty.usbmodem58760431541 \\
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
--robot.id=black \\
--teleop.type=so100_leader \\
--teleop.port=/dev/tty.usbmodem58760431551 \\
--teleop.id=blue \\
--dataset.repo_id=<my_username>/<my_dataset_name> \\
--dataset.num_episodes=2 \\
--dataset.single_task="Grab the cube" \\
--dataset.streaming_encoding=true \\
--dataset.encoder_threads=2 \\
--display_data=true
# <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
# --dataset.vcodec=h264 \
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
# --teleop.type=so100_leader \
# --teleop.port=/dev/tty.usbmodem58760431551 \
# --teleop.id=blue \
# <- Policy optional if you want to record with a policy \
# --policy.path=${HF_USER}/my_policy \
```
Example recording with bimanual so100:
```shell
lerobot-record \
--robot.type=bi_so_follower \
--robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \
--robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \
--robot.id=bimanual_follower \
lerobot-record \\
--robot.type=bi_so_follower \\
--robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \\
--robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \\
--robot.id=bimanual_follower \\
--robot.left_arm_config.cameras='{
wrist: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30},
top: {"type": "opencv", "index_or_path": 3, "width": 640, "height": 480, "fps": 30},
}' --robot.right_arm_config.cameras='{
wrist: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
front: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
}' \
--teleop.type=bi_so_leader \
--teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \
--teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \
--teleop.id=bimanual_leader \
--display_data=true \
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
--dataset.num_episodes=25 \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
}' \\
--teleop.type=bi_so_leader \\
--teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \\
--teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \\
--teleop.id=bimanual_leader \\
--display_data=true \\
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \\
--dataset.num_episodes=25 \\
--dataset.single_task="Grab and handover the red cube to the other arm" \\
--dataset.streaming_encoding=true \\
--dataset.encoder_threads=2
```
"""
import logging
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import Any
import torch
from lerobot.cameras import CameraConfig # noqa: F401
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
@@ -86,11 +78,10 @@ from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.common.control_utils import (
init_keyboard_listener,
is_headless,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
LeRobotDataset,
VideoEncodingManager,
@@ -98,21 +89,11 @@ from lerobot.datasets import (
create_initial_features,
safe_stop_image_writer,
)
from lerobot.policies import (
ActionInterpolator,
PreTrainedPolicy,
make_policy,
make_pre_post_processors,
make_robot_action,
)
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_processors,
rename_stats,
)
from lerobot.robots import ( # noqa: F401
Robot,
@@ -146,7 +127,6 @@ from lerobot.teleoperators import ( # noqa: F401
)
from lerobot.teleoperators.keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
@@ -157,71 +137,12 @@ from lerobot.utils.utils import (
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
# Maximum number of frames to buffer per camera when using streaming encoding.
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
encoder_queue_maxsize: int = 30
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
if self.single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
@dataclass
class RecordConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Whether to control the robot with a teleoperator
# Teleoperator to control the robot (required)
teleop: TeleoperatorConfig | None = None
# Whether to control the robot with a policy
policy: PreTrainedConfig | None = None
# Display all cameras on screen
display_data: bool = False
# Display data on a remote Rerun server
@@ -234,27 +155,14 @@ class RecordConfig:
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
# Only applies when using a policy (not teleop)
interpolation_multiplier: int = 1
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.teleop is None and self.policy is None:
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
if self.teleop is None:
raise ValueError(
"A teleoperator is required for recording. "
"Use --teleop.type=... to specify one. "
"For policy-based deployment, use lerobot-rollout instead."
)
""" --------------- record_loop() data flow --------------------------
@@ -264,18 +172,14 @@ class RecordConfig:
V
[ robot_observation_processor ] ---> processed_obs
V
.-----( ACTION LOGIC )------------------.
V V
[ From Teleoperator ] [ From Policy ]
| |
| [teleop.get_action] -> raw_action | [predict_action]
| | | |
| V | V
| [teleop_action_processor] | |
| | | |
'---> processed_teleop_action '---> processed_policy_action
| |
'-------------------------.-------------'
[ Teleoperator ]
|
| [teleop.get_action] -> raw_action
| |
| V
| [teleop_action_processor]
| |
'---> processed_teleop_action
V
[ robot_action_processor ] --> robot_action_to_send
V
@@ -303,13 +207,9 @@ def record_loop(
], # runs after robot
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
interpolator: ActionInterpolator | None = None,
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -340,21 +240,7 @@ def record_loop(
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
)
# Reset policy and processor if they are provided
if policy is not None and preprocessor is not None and postprocessor is not None:
policy.reset()
preprocessor.reset()
postprocessor.reset()
# Reset interpolator if provided
if interpolator is not None:
interpolator.reset()
# Calculate control interval based on interpolation
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
# Pre-compute action key order outside the hot loop — it won't change mid-episode.
action_keys = sorted(robot.action_features) if use_interpolation else []
control_interval = 1 / fps
no_action_count = 0
timestamp = 0
@@ -372,63 +258,11 @@ def record_loop(
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
obs_processed = robot_observation_processor(obs)
if policy is not None or dataset is not None:
if dataset is not None:
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Track whether this iteration should be recorded to the dataset.
# Interpolated-only iterations send actions to the robot but don't record frames,
# keeping the dataset at the original fps while the robot moves at the higher rate.
is_record_frame = True
# Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None:
# With interpolation: only call policy when interpolator needs new action
if use_interpolation:
ran_inference = False
if interpolator.needs_new_action():
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy = make_robot_action(action_values, dataset.features)
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
interpolator.add(action_tensor)
ran_inference = True
interp_action = interpolator.get()
if interp_action is not None:
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
action_values = robot_action_to_send
else:
continue
is_record_frame = ran_inference
else:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
# Applies a pipeline to the action, default is IdentityProcessor
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
action_values = robot_action_to_send
elif policy is None and isinstance(teleop, Teleoperator):
# Get action from teleop
if isinstance(teleop, Teleoperator):
act = teleop.get_action()
if robot.name == "unitree_g1":
teleop.send_feedback(obs)
@@ -438,7 +272,7 @@ def record_loop(
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
elif policy is None and isinstance(teleop, list):
elif isinstance(teleop, list):
arm_action = teleop_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
keyboard_action = teleop_keyboard.get_action()
@@ -451,7 +285,7 @@ def record_loop(
no_action_count += 1
if no_action_count == 1 or no_action_count % 10 == 0:
logging.warning(
"No policy or teleoperator provided, skipping action generation. "
"No teleoperator provided, skipping action generation. "
"This is likely to happen when resetting the environment without a teleop device. "
"The robot won't be at its rest position at the start of the next episode."
)
@@ -463,8 +297,8 @@ def record_loop(
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Write to dataset (only on real policy frames, not interpolated-only iterations)
if dataset is not None and is_record_frame:
# Write to dataset
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
@@ -488,7 +322,12 @@ def record_loop(
@parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset:
def record(
cfg: RecordConfig,
teleop_action_processor: RobotProcessorPipeline | None = None,
robot_action_processor: RobotProcessorPipeline | None = None,
robot_observation_processor: RobotProcessorPipeline | None = None,
) -> LeRobotDataset:
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
@@ -502,7 +341,16 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Fall back to identity pipelines when the caller doesn't supply processors.
if (
teleop_action_processor is None
or robot_action_processor is None
or robot_observation_processor is None
):
_t, _r, _o = make_default_processors()
teleop_action_processor = teleop_action_processor or _t
robot_action_processor = robot_action_processor or _r
robot_observation_processor = robot_observation_processor or _o
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
@@ -540,8 +388,12 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
# Reject eval_ prefix — for policy evaluation use lerobot-rollout
if cfg.dataset.repo_id.startswith("eval_"):
raise ValueError(
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
)
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
@@ -558,30 +410,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
encoder_threads=cfg.dataset.encoder_threads,
)
# Load pretrained policy
policy = (
None
if cfg.policy is None
else make_policy(cfg.policy, ds_meta=dataset.meta, rename_map=cfg.dataset.rename_map)
)
preprocessor = None
postprocessor = None
interpolator = None
if cfg.policy is not None:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
},
)
# Create interpolator for smoother policy control
if cfg.interpolation_multiplier > 1:
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
robot.connect()
if teleop is not None:
teleop.connect()
@@ -605,14 +433,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
interpolator=interpolator,
display_compressed_images=display_compressed_images,
)
@@ -660,7 +484,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
if dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
else:
logging.warning("No episodes saved — skipping push to hub")
log_say("Exiting", cfg.play_sounds)
return dataset
+211
View File
@@ -0,0 +1,211 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Policy deployment engine with pluggable rollout strategies.
``lerobot-rollout`` is the single CLI for running trained policies on
real robots.
Strategies
----------
--strategy.type=base Autonomous rollout, no recording
--strategy.type=sentry Continuous recording with auto-upload
--strategy.type=highlight Ring buffer + keystroke save
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
Inference backends
------------------
--inference.type=sync One policy call per control tick (default)
--inference.type=rtc Real-Time Chunking for slow VLA models
Usage examples
--------------
::
# Base mode — quick evaluation with sync inference
lerobot-rollout \\
--strategy.type=base \\
--policy.path=lerobot/act_koch_real \\
--robot.type=koch_follower \\
--robot.port=/dev/ttyACM0 \\
--task="pick up cube" --duration=30
# Base mode — RTC inference for slow VLAs (Pi0, Pi0.5, SmolVLA)
lerobot-rollout \\
--strategy.type=base \\
--policy.path=lerobot/pi0_base \\
--inference.type=rtc \\
--inference.rtc.execution_horizon=10 \\
--inference.rtc.max_guidance_weight=10.0 \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
--task="pick up cube" --duration=60
# Sentry mode — continuous recording with periodic upload
lerobot-rollout \\
--strategy.type=sentry \\
--strategy.upload_every_n_episodes=5 \\
--policy.path=lerobot/pi0_base \\
--inference.type=rtc \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--dataset.repo_id=user/sentry-data \\
--dataset.single_task="patrol" --duration=3600
# Highlight mode — ring buffer, press 's' to save, 'h' to push
lerobot-rollout \\
--strategy.type=highlight \\
--strategy.ring_buffer_seconds=30 \\
--policy.path=lerobot/act_koch_real \\
--robot.type=koch_follower \\
--robot.port=/dev/ttyACM0 \\
--dataset.repo_id=user/highlight-data \\
--dataset.single_task="pick up cube"
# DAgger mode — human-in-the-loop corrections only
lerobot-rollout \\
--strategy.type=dagger \\
--strategy.num_episodes=20 \\
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \\
--robot.type=bi_openarm_follower \\
--teleop.type=openarm_mini \\
--dataset.repo_id=user/hil-data \\
--dataset.single_task="Fold the T-shirt"
# DAgger mode — continuous recording with RTC inference
lerobot-rollout \\
--strategy.type=dagger \\
--strategy.record_autonomous=true \\
--strategy.num_episodes=50 \\
--inference.type=rtc \\
--inference.rtc.execution_horizon=10 \\
--policy.path=user/my_pi0_policy \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--teleop.type=so101_leader \\
--teleop.port=/dev/ttyACM1 \\
--dataset.repo_id=user/dagger-rtc-data \\
--dataset.single_task="Grasp the block"
# With Rerun visualization and torch.compile
lerobot-rollout \\
--strategy.type=base \\
--policy.path=lerobot/act_koch_real \\
--robot.type=koch_follower \\
--robot.port=/dev/ttyACM0 \\
--task="pick up cube" --duration=60 \\
--display_data=true \\
--use_torch_compile=true
# Resume a previous sentry recording session
lerobot-rollout \\
--strategy.type=sentry \\
--policy.path=user/my_policy \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--dataset.repo_id=user/sentry-data \\
--dataset.single_task="patrol" \\
--resume=true
"""
import logging
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
omx_follower,
openarm_follower,
reachy2,
so_follower,
unitree_g1 as unitree_g1_robot,
)
from lerobot.rollout import RolloutConfig, build_rollout_context, create_strategy
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_so_leader,
homunculus,
koch_leader,
omx_leader,
openarm_leader,
openarm_mini,
reachy2_teleoperator,
so_leader,
unitree_g1,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
from lerobot.utils.visualization_utils import init_rerun
logger = logging.getLogger(__name__)
@parser.wrap()
def rollout(cfg: RolloutConfig):
"""Main entry point for policy deployment."""
init_logging()
if cfg.display_data:
logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port)
init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
logger.info("Building rollout context...")
ctx = build_rollout_context(cfg, shutdown_event)
strategy = create_strategy(cfg.strategy)
logger.info("Rollout strategy: %s", cfg.strategy.type)
logger.info(
"Robot: %s | FPS: %.0f | Duration: %s",
cfg.robot.type if cfg.robot else "?",
cfg.fps,
f"{cfg.duration}s" if cfg.duration > 0 else "infinite",
)
try:
strategy.setup(ctx)
logger.info("Rollout setup complete, starting rollout...")
strategy.run(ctx)
except KeyboardInterrupt:
logger.info("Interrupted by user")
finally:
strategy.teardown(ctx)
logger.info("Rollout finished")
def main():
"""CLI entry point for ``lerobot-rollout``."""
register_third_party_plugins()
rollout()
if __name__ == "__main__":
main()
+116
View File
@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
+83
View File
@@ -0,0 +1,83 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic foot pedal listener using evdev.
Callers supply a callback receiving the pressed key code (e.g. ``"KEY_A"``)
and an optional device path. The listener runs in a daemon thread and
silently no-ops when :mod:`evdev` is not installed or the device is
unavailable. Strategy-specific key mapping logic lives in the caller.
"""
from __future__ import annotations
import logging
import threading
from collections.abc import Callable
logger = logging.getLogger(__name__)
DEFAULT_PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
def start_pedal_listener(
on_press: Callable[[str], None],
device_path: str = DEFAULT_PEDAL_DEVICE,
) -> threading.Thread | None:
"""Spawn a daemon thread that forwards pedal key-press codes to ``on_press``.
Parameters
----------
on_press:
Callback invoked with the pressed key code string (e.g. ``"KEY_A"``)
on each pedal press event. The callback runs in the listener thread
and must be thread-safe.
device_path:
Linux input device path (e.g. ``/dev/input/by-id/...``).
Returns
-------
The started daemon :class:`threading.Thread`, or ``None`` when
:mod:`evdev` is not installed (optional dependency; silent no-op).
"""
try:
from evdev import InputDevice, categorize, ecodes
except ImportError:
return None
def pedal_reader() -> None:
try:
dev = InputDevice(device_path)
logger.info("Pedal connected: %s", dev.name)
for ev in dev.read_loop():
if ev.type != ecodes.EV_KEY:
continue
key = categorize(ev)
code = key.keycode
if isinstance(code, (list, tuple)):
code = code[0]
if key.keystate != 1: # only key-down events
continue
try:
on_press(code)
except Exception as cb_err: # pragma: no cover - defensive
logger.warning("Pedal callback error: %s", cb_err)
except (FileNotFoundError, PermissionError):
pass
except Exception as e:
logger.warning("Pedal error: %s", e)
thread = threading.Thread(target=pedal_reader, daemon=True, name="PedalListener")
thread.start()
return thread
@@ -17,9 +17,9 @@
import pytest
import torch
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.action_interpolator import ActionInterpolator
# ====================== Fixtures ======================
-82
View File
@@ -24,10 +24,6 @@ def lerobot_train(args):
return run_command(cmd="lerobot-train", module="lerobot_train", args=args)
def lerobot_record(args):
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
def resolve_model_id_for_peft_training(policy_type):
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
if policy_type == "smolvla":
@@ -155,81 +151,3 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
f"--output_dir={output_dir}",
]
)
class DummyRobot:
name = "dummy"
cameras = []
action_features = {"foo": 1.0, "bar": 2.0}
observation_features = {"obs1": 1.0, "obs2": 2.0}
is_connected = True
def connect(self, *args):
pass
def disconnect(self):
pass
def dummy_make_robot_from_config(*args, **kwargs):
return DummyRobot()
@pytest.mark.parametrize("policy_type", ["smolvla"])
@skip_if_package_missing("peft")
def test_peft_record_loads_policy(policy_type, tmp_path):
"""Train a policy with PEFT and attempt to load it with `lerobot-record`."""
from peft import PeftModel
output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
lerobot_train(
[
f"--policy.path={model_id}",
"--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
"--steps=1",
f"--output_dir={output_dir}",
]
)
policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
dataset_dir = tmp_path / "eval_pusht"
single_task = "move the table"
loaded_policy = None
def dummy_record_loop(*args, **kwargs):
nonlocal loaded_policy
if "dataset" not in kwargs:
return
dataset = kwargs["dataset"]
dataset.add_frame({"task": single_task})
loaded_policy = kwargs["policy"]
with (
patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
# disable record loop since we're only interested in successful loading of the policy.
patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
# disable speech output
patch("lerobot.utils.utils.say"),
):
lerobot_record(
[
f"--policy.path={policy_dir}",
"--robot.type=so101_follower",
"--robot.port=/dev/null",
"--dataset.repo_id=lerobot/eval_pusht",
f'--dataset.single_task="{single_task}"',
f"--dataset.root={dataset_dir}",
"--dataset.push_to_hub=false",
]
)
assert isinstance(loaded_policy, PeftModel)
+2 -1
View File
@@ -21,8 +21,9 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
from lerobot.scripts.lerobot_record import RecordConfig, record
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
from tests.fixtures.constants import DUMMY_REPO_ID
+338
View File
@@ -0,0 +1,338 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Minimal tests for the rollout module's public API."""
from __future__ import annotations
import dataclasses
from unittest.mock import MagicMock
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# ---------------------------------------------------------------------------
# Import smoke tests
# ---------------------------------------------------------------------------
def test_rollout_top_level_imports():
import lerobot.rollout
for name in lerobot.rollout.__all__:
assert hasattr(lerobot.rollout, name), f"Missing export: {name}"
def test_inference_submodule_imports():
import lerobot.rollout.inference
for name in lerobot.rollout.inference.__all__:
assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}"
def test_strategies_submodule_imports():
import lerobot.rollout.strategies
for name in lerobot.rollout.strategies.__all__:
assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}"
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
def test_strategy_config_types():
from lerobot.rollout import (
BaseStrategyConfig,
DAggerStrategyConfig,
HighlightStrategyConfig,
SentryStrategyConfig,
)
assert BaseStrategyConfig().type == "base"
assert SentryStrategyConfig().type == "sentry"
assert HighlightStrategyConfig().type == "highlight"
assert DAggerStrategyConfig().type == "dagger"
def test_dagger_config_invalid_input_device():
from lerobot.rollout import DAggerStrategyConfig
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
DAggerStrategyConfig(input_device="joystick")
def test_dagger_config_defaults():
from lerobot.rollout import DAggerStrategyConfig
cfg = DAggerStrategyConfig()
assert cfg.num_episodes == 10
assert cfg.record_autonomous is False
assert cfg.input_device == "keyboard"
def test_inference_config_types():
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
assert SyncInferenceConfig().type == "sync"
rtc = RTCInferenceConfig()
assert rtc.type == "rtc"
assert rtc.queue_threshold == 30
assert rtc.rtc is not None
def test_sentry_config_defaults():
from lerobot.rollout import SentryStrategyConfig
cfg = SentryStrategyConfig()
assert cfg.upload_every_n_episodes == 5
assert cfg.target_video_file_size_mb is None
# ---------------------------------------------------------------------------
# RolloutRingBuffer
# ---------------------------------------------------------------------------
def test_ring_buffer_append_and_eviction():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
# max_frames = 5
for i in range(8):
buf.append({"val": i})
assert len(buf) == 5
def test_ring_buffer_drain():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
for i in range(3):
buf.append({"val": i})
frames = buf.drain()
assert len(frames) == 3
assert len(buf) == 0
assert buf.estimated_bytes == 0
def test_ring_buffer_clear():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
buf.append({"val": 1})
buf.clear()
assert len(buf) == 0
assert buf.estimated_bytes == 0
def test_ring_buffer_tensor_bytes():
from lerobot.rollout import RolloutRingBuffer
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
buf.append({"tensor": t})
assert buf.estimated_bytes >= 400
# ---------------------------------------------------------------------------
# ThreadSafeRobot
# ---------------------------------------------------------------------------
def test_thread_safe_robot_delegates():
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
robot.connect()
wrapper = ThreadSafeRobot(robot)
obs = wrapper.get_observation()
assert "motor_1.pos" in obs
assert "motor_2.pos" in obs
assert "motor_3.pos" in obs
action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
result = wrapper.send_action(action)
assert result == action
robot.disconnect()
def test_thread_safe_robot_properties():
from lerobot.rollout import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
robot = MockRobot(MockRobotConfig(n_motors=3))
robot.connect()
wrapper = ThreadSafeRobot(robot)
assert wrapper.name == "mock_robot"
assert "motor_1.pos" in wrapper.observation_features
assert "motor_1.pos" in wrapper.action_features
assert wrapper.is_connected is True
assert wrapper.inner is robot
robot.disconnect()
# ---------------------------------------------------------------------------
# Strategy factory
# ---------------------------------------------------------------------------
def test_create_strategy_dispatches():
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
def test_create_strategy_unknown_raises():
from lerobot.rollout.strategies import create_strategy
cfg = MagicMock()
cfg.type = "bogus"
with pytest.raises(ValueError, match="Unknown strategy type"):
create_strategy(cfg)
# ---------------------------------------------------------------------------
# Inference factory
# ---------------------------------------------------------------------------
def test_create_inference_engine_sync():
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
engine = create_inference_engine(
SyncInferenceConfig(),
policy=MagicMock(),
preprocessor=MagicMock(),
postprocessor=MagicMock(),
robot_wrapper=MagicMock(robot_type="mock"),
hw_features={},
dataset_features={},
ordered_action_keys=["k"],
task="test",
fps=30.0,
device="cpu",
)
assert isinstance(engine, SyncInferenceEngine)
# ---------------------------------------------------------------------------
# Pure functions
# ---------------------------------------------------------------------------
def test_estimate_max_episode_seconds_no_video():
from lerobot.rollout.strategies import estimate_max_episode_seconds
assert estimate_max_episode_seconds({}, fps=30.0) == 600.0
def test_estimate_max_episode_seconds_with_video():
from lerobot.rollout.strategies import estimate_max_episode_seconds
features = {"cam": {"dtype": "video", "shape": (3, 480, 640)}}
result = estimate_max_episode_seconds(features, fps=30.0)
assert result > 0
# With a real camera, duration should differ from the fallback
assert result != 600.0
def test_safe_push_to_hub():
from lerobot.rollout.strategies import safe_push_to_hub
ds = MagicMock()
ds.num_episodes = 0
assert safe_push_to_hub(ds) is False
ds.push_to_hub.assert_not_called()
ds.num_episodes = 5
assert safe_push_to_hub(ds, tags=["test"]) is True
ds.push_to_hub.assert_called_once_with(tags=["test"], private=False)
# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------
def test_dagger_full_transition_cycle():
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
assert events.phase == DAggerPhase.AUTONOMOUS
# AUTONOMOUS -> PAUSED
events.request_transition("pause_resume")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED)
# PAUSED -> CORRECTING
events.request_transition("correction")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING)
# CORRECTING -> PAUSED
events.request_transition("correction")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED)
# PAUSED -> AUTONOMOUS
events.request_transition("pause_resume")
old, new = events.consume_transition()
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS)
def test_dagger_invalid_transition_ignored():
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("correction") # Not valid from AUTONOMOUS
assert events.consume_transition() is None
assert events.phase == DAggerPhase.AUTONOMOUS
def test_dagger_events_reset():
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
events = DAggerEvents()
events.request_transition("pause_resume")
events.consume_transition() # -> PAUSED
events.upload_requested.set()
events.reset()
assert events.phase == DAggerPhase.AUTONOMOUS
assert not events.upload_requested.is_set()
# ---------------------------------------------------------------------------
# Context dataclass
# ---------------------------------------------------------------------------
def test_rollout_context_fields():
from lerobot.rollout import RolloutContext
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
+1 -1
View File
@@ -24,7 +24,7 @@ import pytest
pytest.importorskip("grpc")
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test