Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-14 08:09:45 +00:00
Compare commits
38 Commits
| SHA1 |
|---|
| 8e21268c29 |
| 4130d4a4a5 |
| 47bb840a55 |
| 9519ff5e09 |
| 32a27cae8a |
| 5c43fa1cce |
| 3f16d98a9b |
| 52f508c51c |
| a8b72d9615 |
| 760220d532 |
| a99943ca26 |
| 8cee56e2d6 |
| a76874f35e |
| a9821af61b |
| 35bb2c7459 |
| 051f6c6803 |
| d4a229444b |
| 098ebb4d72 |
| 9bc2df80bb |
| 04ae0312a2 |
| cc634de9e7 |
| 3eda5712d3 |
| 783ec6e232 |
| 4e3175ff15 |
| edd7fc52a8 |
| 0f0f8b8961 |
| bd74f6733d |
| 79db54dc34 |
| 6f4a96333e |
| 6ae07878f7 |
| 9021d2d240 |
| 10d05e03bc |
| 60e7d67cb8 |
| f2c29d78cf |
| 8bc47e4318 |
| 49f32b9796 |
| f55782f9f7 |
| 05a2604d6e |
@@ -2,11 +2,6 @@

Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.

## Type / Scope

- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
- **Scope**: (optional — name of module or package affected)

## Summary / Motivation

- One-paragraph description of what changes and why.
@@ -19,28 +14,14 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S

## What changed

- Short, concrete bullets of the modifications (files/behaviour).
- Short, concrete bullets explaining the functional changes (how the behavior or output differs now).
- Short note if this introduces breaking changes and migration steps.

## How was this tested (or how to run locally)

- Tests added: list new tests or test files.
- Tests added: list new tests or test files. `pytest -q tests/ -k <keyword>`
- Manual checks / dataset runs performed.
- Instructions for the reviewer

Example:

- Ran the relevant tests:

```bash
pytest -q tests/ -k <keyword>
```

- Reproduce with a quick example or CLI (if applicable):

```bash
lerobot-train --some.option=true
```

- Instructions for the reviewer for reproducing with a quick example or CLI (if applicable)

## Checklist (required before merge)
@@ -48,6 +29,7 @@ Example:

- [ ] All tests pass locally (`pytest`)
- [ ] Documentation updated
- [ ] CI is green
- [ ] Community Review: I have reviewed another contributor's open PR and linked it here: # (insert PR number/link)

## Reviewer notes
@@ -33,7 +33,7 @@ jobs:
      github.event.workflow_run.event == 'pull_request' &&
      github.event.workflow_run.conclusion == 'success' &&
      github.repository == 'huggingface/lerobot'
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
    with:
      package_name: lerobot
    secrets:
@@ -217,6 +217,24 @@ jobs:
      - name: Run end-to-end tests
        run: make test-end-to-end

  slack-notification:
    name: Slack Notification
    needs: [cpu-tests, gpu-tests, upgrade-lock]
    if: always() && needs.upgrade-lock.outputs.changed == 'true'
    runs-on: ubuntu-latest
    permissions:
      contents: read
    env:
      CI_SLACK_CHANNEL: ${{ secrets.CI_SLACK_CHANNEL }}
    steps:
      - name: Post to a Slack channel
        uses: huggingface/hf-workflows/.github/actions/post-slack@a88e7fa2eaee28de5a4d6142381b1fb792349b67 # main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: "Results of the latest dependency tests (CPU + GPU)"
          status: ${{ (needs.cpu-tests.result == 'success' && needs.gpu-tests.result == 'success') && 'success' || 'failure' }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

  # This job creates or updates a PR with the upgraded lockfile
  open-pr:
    name: Open PR
+4 -1
@@ -78,6 +78,9 @@ Use the templates for required fields and examples.

- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).

One member of the LeRobot team will then review your contribution.

> [!IMPORTANT]
> Community Review Policy: To help scale our efforts and foster a collaborative environment, we ask contributors to review at least one other person's open PR before their own receives attention. This shared responsibility multiplies our review capacity and helps everyone's code get merged faster!

Once you have submitted your PR and completed a peer review, a member of the LeRobot team will review your contribution.

Thank you for contributing to LeRobot!
@@ -61,6 +61,8 @@
        title: SARM
      title: "Reward Models"
  - sections:
      - local: inference
        title: Policy Deployment (lerobot-rollout)
      - local: async
        title: Use Async Inference
      - local: rtc
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea

### Teleoperator Requirements

The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:

- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)

**Compatible teleoperators in the current `examples/hil` scripts:**
**Compatible teleoperators:**

- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm

> [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.

---
## Script

A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:

| Mode                     | Flag                 | Models                |
| ------------------------ | -------------------- | --------------------- |
| Standard (default)       | _(no flag needed)_   | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA   |

| Mode                     | Flag                   | Models                |
| ------------------------ | ---------------------- | --------------------- |
| Standard (default)       | _(no flag needed)_     | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA   |

---
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \

**Standard inference (ACT, Diffusion Policy):**

```bash
python examples/hil/hil_data_collection.py \
lerobot-rollout --strategy.type=dagger \
  --robot.type=bi_openarm_follower \
  --robot.left_arm_config.port=can1 \
  --robot.left_arm_config.side=left \
@@ -111,8 +111,7 @@ python examples/hil/hil_data_collection.py \
  --dataset.repo_id=your-username/hil-dataset \
  --dataset.single_task="Fold the T-shirt properly" \
  --dataset.fps=30 \
  --dataset.episode_time_s=1000 \
  --dataset.num_episodes=50 \
  --strategy.num_episodes=50 \
  --interpolation_multiplier=2
```
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \

For models with high inference latency, enable RTC for smooth execution:

```bash
python examples/hil/hil_data_collection.py \
  --rtc.enabled=true \
  --rtc.execution_horizon=20 \
  --rtc.max_guidance_weight=5.0 \
  --rtc.prefix_attention_schedule=LINEAR \
lerobot-rollout --strategy.type=dagger \
  --inference.type=rtc \
  --inference.rtc.execution_horizon=20 \
  --inference.rtc.max_guidance_weight=5.0 \
  --inference.rtc.prefix_attention_schedule=LINEAR \
  --robot.type=bi_openarm_follower \
  --robot.left_arm_config.port=can1 \
  --robot.left_arm_config.side=left \
@@ -139,8 +138,7 @@ python examples/hil/hil_data_collection.py \
  --dataset.repo_id=your-username/hil-rtc-dataset \
  --dataset.single_task="Fold the T-shirt properly" \
  --dataset.fps=30 \
  --dataset.episode_time_s=1000 \
  --dataset.num_episodes=50 \
  --strategy.num_episodes=50 \
  --interpolation_multiplier=3
```
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea

- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.

- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.

- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
+32 -105
@@ -32,6 +32,12 @@ Once you’ve gathered enough trajectories, you’ll train a neural network to i

If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support.

<Tip>

Want to quickly get the right commands for your setup? The [quickstart notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) [](https://colab.research.google.com/github/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) lets you configure your robot once and generates all the commands below ready to paste.

</Tip>

## Set up and Calibrate

If you haven't yet set up and calibrated your robot and teleop device, please do so by following the robot-specific tutorial.
@@ -503,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \

## Run inference and evaluate your policy

You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:

<hfoptions id="eval">
<hfoption id="Command">
<hfoption id="Base mode (no recording)">

```bash
lerobot-record \
lerobot-rollout \
  --strategy.type=base \
  --policy.path=${HF_USER}/my_policy \
  --robot.type=so100_follower \
  --robot.port=/dev/ttyACM1 \
  --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
  --robot.id=my_awesome_follower_arm \
  --display_data=false \
  --dataset.repo_id=${HF_USER}/eval_so100 \
  --dataset.single_task="Put lego brick into the transparent box" \
  --dataset.streaming_encoding=true \
  --dataset.encoder_threads=2 \
  # --dataset.vcodec=auto \
  # <- Teleop optional if you want to teleoperate in between episodes \
  # --teleop.type=so100_leader \
  # --teleop.port=/dev/ttyACM0 \
  # --teleop.id=my_awesome_leader_arm \
  --policy.path=${HF_USER}/my_policy
  --task="Put lego brick into the transparent box" \
  --duration=60
```
</hfoption>
<hfoption id="API example">

<!-- prettier-ignore-start -->

```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.policies.act import ACTPolicy
from lerobot.policies import make_pre_post_processors
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun


NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"

# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
    port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
)

# Initialize the robot
robot = SO100Follower(robot_config)

# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)

# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}

# Create the dataset
dataset = LeRobotDataset.create(
    repo_id=HF_DATASET_ID,
    fps=FPS,
    features=dataset_features,
    robot_type=robot.name,
    use_videos=True,
    image_writer_threads=4,
)

# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")

# Connect the robot
robot.connect()

preprocessor, postprocessor = make_pre_post_processors(
    policy_cfg=policy,
    pretrained_path=HF_MODEL_ID,
    dataset_stats=dataset.meta.stats,
)

for episode_idx in range(NUM_EPISODES):
    log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

    # Run the policy inference loop
    record_loop(
        robot=robot,
        events=events,
        fps=FPS,
        policy=policy,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        dataset=dataset,
        control_time_s=EPISODE_TIME_SEC,
        single_task=TASK_DESCRIPTION,
        display_data=True,
    )

    dataset.save_episode()

# Clean up
robot.disconnect()
dataset.push_to_hub()
```
<hfoption id="Sentry mode (with recording)">

```bash
lerobot-rollout \
  --strategy.type=sentry \
  --strategy.upload_every_n_episodes=5 \
  --policy.path=${HF_USER}/my_policy \
  --robot.type=so100_follower \
  --robot.port=/dev/ttyACM1 \
  --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
  --dataset.repo_id=${HF_USER}/eval_so100 \
  --dataset.single_task="Put lego brick into the transparent box" \
  --duration=600
```

<!-- prettier-ignore-end -->

</hfoption>
</hfoptions>
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
The `--strategy.type` flag selects the execution mode:

1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
2. The name of the dataset begins with `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).

- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))

All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
@@ -0,0 +1,261 @@

# Policy Deployment (lerobot-rollout)

`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.

## Quick Start

No extra dependencies are needed beyond your robot and policy extras.

```bash
lerobot-rollout \
  --strategy.type=base \
  --policy.path=lerobot/act_koch_real \
  --robot.type=koch_follower \
  --robot.port=/dev/ttyACM0 \
  --task="pick up cube" \
  --duration=30
```

This runs the policy for 30 seconds with no recording.

---
## Strategies

Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.

### Base (`--strategy.type=base`)

Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.

```bash
lerobot-rollout \
  --strategy.type=base \
  --policy.path=${HF_USER}/my_policy \
  --robot.type=so100_follower \
  --robot.port=/dev/ttyACM0 \
  --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
  --task="Put lego brick into the box" \
  --duration=60
```

| Flag             | Description                                             |
| ---------------- | ------------------------------------------------------- |
| `--duration`     | Run time in seconds (0 = infinite)                      |
| `--task`         | Task description passed to the policy                   |
| `--display_data` | Stream observations/actions to Rerun for visualization  |
### Sentry (`--strategy.type=sentry`)

Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.

Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.

```bash
lerobot-rollout \
  --strategy.type=sentry \
  --strategy.upload_every_n_episodes=5 \
  --policy.path=${HF_USER}/my_policy \
  --robot.type=so100_follower \
  --robot.port=/dev/ttyACM0 \
  --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
  --dataset.repo_id=${HF_USER}/eval_data \
  --dataset.single_task="Put lego brick into the box" \
  --duration=3600
```

| Flag                                    | Description                                                  |
| --------------------------------------- | ------------------------------------------------------------ |
| `--strategy.upload_every_n_episodes`    | Push to Hub every N episodes (default: 5)                     |
| `--strategy.target_video_file_size_mb`  | Target video file size for episode rotation (default: auto)   |
| `--dataset.repo_id`                     | **Required.** Hub repository for the recorded dataset         |
| `--dataset.push_to_hub`                 | Whether to push to Hub on teardown (default: true)            |
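The episode-rotation arithmetic above can be pictured with a small sketch. This is illustrative only: the function name `frames_per_episode` and the assumed encoder rate (`bits_per_pixel`) are hypothetical, not lerobot's actual internal heuristic.

```python
def frames_per_episode(width: int, height: int, fps: int,
                       target_mb: float, bits_per_pixel: float = 0.1) -> int:
    """Estimate how many frames fit in a video file of roughly target_mb (sketch)."""
    bits_per_frame = width * height * bits_per_pixel  # assumed compressed rate
    target_bits = target_mb * 8 * 1024 * 1024
    return int(target_bits // bits_per_frame)

# 640x480 @ 30 FPS, aiming for ~100 MB video files:
n = frames_per_episode(640, 480, 30, target_mb=100)
print(n, n // 30)  # frames per episode, approximate seconds per episode
```

Under these assumptions, larger target files or lower resolutions yield longer episodes; the real strategy exposes the target size via `--strategy.target_video_file_size_mb`.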

### Highlight (`--strategy.type=highlight`)

Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.

```bash
lerobot-rollout \
  --strategy.type=highlight \
  --strategy.ring_buffer_seconds=30 \
  --strategy.save_key=s \
  --strategy.push_key=h \
  --policy.path=${HF_USER}/my_policy \
  --robot.type=koch_follower \
  --robot.port=/dev/ttyACM0 \
  --dataset.repo_id=${HF_USER}/highlight_data \
  --dataset.single_task="Pick up the red cube"
```

**Keyboard controls:**

| Key                | Action                                                   |
| ------------------ | -------------------------------------------------------- |
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
| `h` (configurable) | Push dataset to Hub                                      |
| `ESC`              | Stop the session                                         |

| Flag                                    | Description                                    |
| --------------------------------------- | ---------------------------------------------- |
| `--strategy.ring_buffer_seconds`        | Duration of buffered telemetry (default: 30)   |
| `--strategy.ring_buffer_max_memory_mb`  | Memory cap for the ring buffer (default: 2048) |
| `--strategy.save_key`                   | Key to toggle recording (default: `s`)         |
| `--strategy.push_key`                   | Key to push to Hub (default: `h`)              |
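The ring-buffer behavior can be sketched in a few lines (a hypothetical model, not the lerobot implementation; the real strategy also enforces the memory cap, which this sketch omits): a deque bounded to `seconds * fps` frames that is drained into an episode when the save key is pressed.

```python
from collections import deque


class TelemetryRingBuffer:
    """Keep only the last `seconds * fps` frames in memory (sketch)."""

    def __init__(self, seconds: int = 30, fps: int = 30):
        self.buffer = deque(maxlen=seconds * fps)

    def append(self, frame: dict) -> None:
        # Oldest frames are dropped automatically once maxlen is reached.
        self.buffer.append(frame)

    def flush(self) -> list[dict]:
        # On save-key press: drain the buffered history into a new episode.
        frames = list(self.buffer)
        self.buffer.clear()
        return frames


buf = TelemetryRingBuffer(seconds=2, fps=10)  # capacity: 20 frames
for i in range(50):
    buf.append({"step": i})
episode = buf.flush()
print(len(episode), episode[0]["step"])  # 20 30 (only the last 2 s survive)
```

The `maxlen` bound is what keeps memory constant no matter how long the robot runs between saves.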

### DAgger (`--strategy.type=dagger`)

Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).

See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.

**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.

```bash
lerobot-rollout \
  --strategy.type=dagger \
  --strategy.num_episodes=20 \
  --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
  --robot.type=bi_openarm_follower \
  --teleop.type=openarm_mini \
  --dataset.repo_id=${HF_USER}/hil_data \
  --dataset.single_task="Fold the T-shirt"
```

**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).

```bash
lerobot-rollout \
  --strategy.type=dagger \
  --strategy.record_autonomous=true \
  --strategy.num_episodes=50 \
  --policy.path=${HF_USER}/my_policy \
  --robot.type=so100_follower \
  --robot.port=/dev/ttyACM0 \
  --teleop.type=so101_leader \
  --teleop.port=/dev/ttyACM1 \
  --dataset.repo_id=${HF_USER}/dagger_data \
  --dataset.single_task="Grasp the block"
```

**Keyboard controls** (default input device):

| Key     | Action                                      |
| ------- | ------------------------------------------- |
| `Space` | Pause / resume policy execution             |
| `Tab`   | Start / stop human correction               |
| `Enter` | Push dataset to Hub (corrections-only mode) |
| `ESC`   | Stop the session                            |

Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.

| Flag                                  | Description                                              |
| ------------------------------------- | -------------------------------------------------------- |
| `--strategy.num_episodes`             | Number of correction episodes to record (default: 10)    |
| `--strategy.record_autonomous`        | Record autonomous frames too (default: false)            |
| `--strategy.upload_every_n_episodes`  | Push to Hub every N episodes (default: 5)                 |
| `--strategy.input_device`             | Input device: `keyboard` or `pedal` (default: keyboard)   |
| `--teleop.type`                       | **Required.** Teleoperator type                           |

---
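As a mental model for the `intervention=True` tagging, here is a hypothetical sketch (not the lerobot API; the helper name and frame layout are assumptions) of grouping consecutive intervention frames into the correction segments that become episodes in corrections-only mode:

```python
def split_corrections(frames: list[dict]) -> list[list[dict]]:
    """Group consecutive intervention frames into correction segments (sketch)."""
    segments, current = [], []
    for frame in frames:
        if frame.get("intervention", False):
            current.append(frame)
        elif current:
            # Intervention window just ended: close the segment.
            segments.append(current)
            current = []
    if current:
        segments.append(current)
    return segments


frames = [
    {"t": 0, "intervention": False},
    {"t": 1, "intervention": True},
    {"t": 2, "intervention": True},
    {"t": 3, "intervention": False},
    {"t": 4, "intervention": True},
]
print([[f["t"] for f in seg] for seg in split_corrections(frames)])  # [[1, 2], [4]]
```

With `record_autonomous=true` the autonomous frames between segments are kept as well, so the same flag can later be used to filter corrections at training time.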

## Inference Backends

Select a backend with `--inference.type=<name>`. All strategies work with both backends.

### Sync (default)

One policy call per control tick. The main loop blocks until the action is computed.

Works with all policies. No extra flags needed.

### Real-Time Chunking (`--inference.type=rtc`)

A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.

Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.

```bash
lerobot-rollout \
  --strategy.type=base \
  --inference.type=rtc \
  --inference.rtc.execution_horizon=10 \
  --inference.rtc.max_guidance_weight=10.0 \
  --policy.path=${HF_USER}/pi0_policy \
  --robot.type=so100_follower \
  --robot.port=/dev/ttyACM0 \
  --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
  --task="Pick up the cube" \
  --duration=60 \
  --device=cuda
```

| Flag                                          | Description                                                    |
| --------------------------------------------- | -------------------------------------------------------------- |
| `--inference.rtc.execution_horizon`           | Steps to blend with previous chunk (default: varies by policy)  |
| `--inference.rtc.max_guidance_weight`         | Consistency enforcement strength (default: varies by policy)    |
| `--inference.rtc.prefix_attention_schedule`   | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS`                |
| `--inference.queue_threshold`                 | Max queue size before backpressure (default: 30)                |

See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
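The producer/consumer pattern behind the RTC backend can be sketched with the standard library (a simplified model, not lerobot's implementation: no chunk blending or guidance, and `chunk_producer` is a hypothetical stand-in for the policy thread):

```python
import queue
import threading
import time

# Bounded queue: cf. --inference.queue_threshold for backpressure.
action_queue: "queue.Queue[int]" = queue.Queue(maxsize=30)


def chunk_producer(n_chunks: int = 3, chunk_size: int = 5) -> None:
    """Background thread: each (slow) policy call yields a whole chunk of actions."""
    for c in range(n_chunks):
        time.sleep(0.05)  # stand-in for slow VLA inference
        for a in range(chunk_size):
            action_queue.put(c * chunk_size + a)


threading.Thread(target=chunk_producer, daemon=True).start()

executed = []
for _ in range(15):  # main control loop ticks at a fixed rate
    executed.append(action_queue.get(timeout=1.0))
print(executed)  # actions arrive in order despite chunked, delayed inference
```

The key property is that the control loop keeps executing actions from the previous chunk while the next one is being computed, which is what hides the model's latency.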

---

## Common Flags

| Flag                              | Description                                                        | Default |
| --------------------------------- | ------------------------------------------------------------------ | ------- |
| `--policy.path`                   | **Required.** HF Hub model ID or local checkpoint path              | --      |
| `--robot.type`                    | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`)   | --      |
| `--robot.port`                    | Serial port for the robot                                           | --      |
| `--robot.cameras`                 | Camera configuration (JSON dict)                                    | --      |
| `--fps`                           | Control loop frequency                                              | 30      |
| `--duration`                      | Run time in seconds (0 = infinite)                                  | 0       |
| `--device`                        | Torch device (`cpu`, `cuda`, `mps`)                                 | auto    |
| `--task`                          | Task description (used when no dataset is provided)                 | --      |
| `--display_data`                  | Stream telemetry to Rerun visualization                             | false   |
| `--display_ip` / `--display_port` | Remote Rerun server address                                         | --      |
| `--interpolation_multiplier`      | Action interpolation factor                                         | 1       |
| `--use_torch_compile`             | Enable `torch.compile` for inference                                | false   |
| `--resume`                        | Resume a previous recording session                                 | false   |
| `--play_sounds`                   | Vocal synthesis for events                                          | true    |

---
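As a rough illustration of what an action interpolation factor can mean, here is a linear-interpolation sketch (the function name is hypothetical and lerobot's actual `--interpolation_multiplier` implementation may differ): inserting `multiplier - 1` intermediate actions between consecutive policy outputs.

```python
def interpolate_actions(actions: list[list[float]], multiplier: int) -> list[list[float]]:
    """Linearly insert `multiplier - 1` intermediate actions between steps (sketch)."""
    if multiplier <= 1 or len(actions) < 2:
        return actions
    out = []
    for a, b in zip(actions, actions[1:]):
        for k in range(multiplier):
            t = k / multiplier
            # Component-wise linear blend between consecutive actions.
            out.append([x + t * (y - x) for x, y in zip(a, b)])
    out.append(actions[-1])
    return out


dense = interpolate_actions([[0.0, 0.0], [1.0, 2.0]], multiplier=2)
print(dense)  # [[0.0, 0.0], [0.5, 1.0], [1.0, 2.0]]
```

Densifying the action stream this way lets the control loop command smoother motion than the raw policy output rate alone.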

## Programmatic Usage

For custom deployments (e.g. with kinematics processors), use the rollout module API directly:

```python
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler

cfg = RolloutConfig(
    robot=my_robot_config,
    policy=my_policy_config,
    strategy=BaseStrategyConfig(),
    inference=SyncInferenceConfig(),
    fps=30,
    duration=60,
    task="my task",
)

signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
    cfg,
    signal_handler.shutdown_event,
    robot_action_processor=my_custom_action_processor,  # optional
    robot_observation_processor=my_custom_obs_processor,  # optional
)

strategy = BaseStrategy(cfg.strategy)
try:
    strategy.setup(ctx)
    strategy.run(ctx)
finally:
    strategy.teardown(ctx)
```

See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.

+7 -3
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"

### Using RTC with Pi0

You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:

```python
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar

## Testing RTC with a Real Robot

```bash
python examples/rtc/eval_with_real_robot.py \
lerobot-rollout \
  --strategy.type=base \
  --policy.path=${HF_USERNAME}/policy_repo_id \
  --inference.type=rtc \
  --inference.rtc.execution_horizon=10 \
  --inference.rtc.max_guidance_weight=10.0 \
  --robot.type=so100_follower \
  --robot.port=/dev/tty.usbmodem58FA0834591 \
  --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
# ... create plots
```

See `examples/rtc/eval_dataset.py` for a complete example of visualization.
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.

## References
@@ -284,7 +284,7 @@ python examples/rtc/eval_with_real_robot.py \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
--inference.type=rtc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
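The effect of the `--inference.rtc.execution_horizon` and `--inference.rtc.max_guidance_weight` flags can be illustrated with a toy blend. This is not the RTC implementation (real-time chunking applies guidance inside the denoising process, not as a post-hoc average); it is only a sketch of how a decaying guidance weight softly pins the start of a freshly predicted action chunk to the tail of the previously committed one. The function name and the list-of-scalars action format are illustrative assumptions:

```python
def blend_chunks(prev_tail, new_chunk, max_guidance_weight=10.0):
    """Softly pin the start of a new action chunk to the tail of the previous one.

    Illustrative only: real RTC applies guidance during denoising rather
    than blending finished chunks.
    """
    overlap = min(len(prev_tail), len(new_chunk))
    out = list(new_chunk)
    for t in range(overlap):
        # Guidance decays linearly over the overlap: early steps stay close
        # to the already-committed actions, later steps trust the new chunk.
        w = max_guidance_weight * (1 - t / overlap)
        alpha = w / (w + 1.0)  # squash the unbounded weight into [0, 1)
        out[t] = alpha * prev_tail[t] + (1 - alpha) * new_chunk[t]
    return out
```

With `max_guidance_weight=10.0` the first blended step sits at roughly 10/11 of the previous chunk's value, decaying to zero influence past the overlap.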
File diff suppressed because it is too large
@@ -1,226 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared utilities for Human-in-the-Loop data collection scripts."""

import logging
import time
from dataclasses import dataclass, field
from pathlib import Path

from lerobot.common.control_utils import is_headless
from lerobot.processor import (
    IdentityProcessorStep,
    RobotAction,
    RobotObservation,
    RobotProcessorPipeline,
    observation_to_transition,
    robot_action_observation_to_transition,
    transition_to_observation,
    transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.robot_utils import precise_sleep

logger = logging.getLogger(__name__)


@dataclass
class HILDatasetConfig:
    repo_id: str
    single_task: str
    root: str | Path | None = None
    fps: int = 30
    episode_time_s: float = 120
    num_episodes: int = 50
    video: bool = True
    push_to_hub: bool = True
    private: bool = False
    tags: list[str] | None = None
    num_image_writer_processes: int = 0
    num_image_writer_threads_per_camera: int = 4
    video_encoding_batch_size: int = 1
    vcodec: str = "auto"
    streaming_encoding: bool = True
    encoder_queue_maxsize: int = 30
    encoder_threads: int | None = None
    rename_map: dict[str, str] = field(default_factory=dict)


def teleop_has_motor_control(teleop: Teleoperator) -> bool:
    """Check if teleoperator has motor control capabilities."""
    return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))


def teleop_disable_torque(teleop: Teleoperator) -> None:
    """Disable teleop torque if supported."""
    if hasattr(teleop, "disable_torque"):
        teleop.disable_torque()


def teleop_enable_torque(teleop: Teleoperator) -> None:
    """Enable teleop torque if supported."""
    if hasattr(teleop, "enable_torque"):
        teleop.enable_torque()


def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
    """Smoothly move teleop to target position if motor control is available."""
    if not teleop_has_motor_control(teleop):
        logger.warning("Teleop does not support motor control - cannot mirror robot position")
        return

    teleop_enable_torque(teleop)
    current = teleop.get_action()
    steps = max(int(duration_s * fps), 1)

    for step in range(steps + 1):
        t = step / steps
        interp = {}
        for k in current:
            if k in target_pos:
                interp[k] = current[k] * (1 - t) + target_pos[k] * t
            else:
                interp[k] = current[k]
        teleop.write_goal_positions(interp)
        time.sleep(1 / fps)
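The per-key linear interpolation inside `teleop_smooth_move_to` can be checked in isolation. A standalone rendition of the loop body (the joint names below are hypothetical examples):

```python
def interpolate_positions(current: dict, target: dict, t: float) -> dict:
    """Blend from `current` toward `target` by fraction t in [0, 1].

    Keys absent from `target` are held at their current value, matching
    the behaviour of the smooth-move loop above.
    """
    return {
        k: current[k] * (1 - t) + target[k] * t if k in target else current[k]
        for k in current
    }
```

For example, halfway (t=0.5) between `{"shoulder.pos": 0.0}` and `{"shoulder.pos": 100.0}` yields `{"shoulder.pos": 50.0}`, while keys missing from the target stay put.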


def init_keyboard_listener():
    """Initialize keyboard listener with HIL controls."""
    events = {
        "exit_early": False,
        "rerecord_episode": False,
        "stop_recording": False,
        "policy_paused": False,
        "correction_active": False,
        "resume_policy": False,
        "in_reset": False,
        "start_next_episode": False,
    }

    if is_headless():
        logger.warning("Headless environment - keyboard controls unavailable")
        return None, events

    from pynput import keyboard

    def on_press(key):
        try:
            if events["in_reset"]:
                if key in [keyboard.Key.space, keyboard.Key.right]:
                    logger.info("[HIL] Starting next episode...")
                    events["start_next_episode"] = True
                elif hasattr(key, "char") and key.char == "c":
                    events["start_next_episode"] = True
                elif key == keyboard.Key.esc:
                    logger.info("[HIL] ESC - Stop recording, pushing to hub...")
                    events["stop_recording"] = True
                    events["start_next_episode"] = True
            else:
                if key == keyboard.Key.space:
                    if not events["policy_paused"] and not events["correction_active"]:
                        logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
                        events["policy_paused"] = True
                elif hasattr(key, "char") and key.char == "c":
                    if events["policy_paused"] and not events["correction_active"]:
                        logger.info("[HIL] Taking control...")
                        events["start_next_episode"] = True
                elif hasattr(key, "char") and key.char == "p":
                    if events["policy_paused"] or events["correction_active"]:
                        logger.info("[HIL] Resuming policy...")
                        events["resume_policy"] = True
                elif key == keyboard.Key.right:
                    logger.info("[HIL] End episode")
                    events["exit_early"] = True
                elif key == keyboard.Key.left:
                    logger.info("[HIL] Re-record episode")
                    events["rerecord_episode"] = True
                    events["exit_early"] = True
                elif key == keyboard.Key.esc:
                    logger.info("[HIL] ESC - Stop recording...")
                    events["stop_recording"] = True
                    events["exit_early"] = True
        except Exception as e:
            logger.info(f"Key error: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events


def make_identity_processors():
    """Create identity processors for recording."""
    teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        steps=[IdentityProcessorStep()],
        to_transition=robot_action_observation_to_transition,
        to_output=transition_to_robot_action,
    )
    obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
        steps=[IdentityProcessorStep()],
        to_transition=observation_to_transition,
        to_output=transition_to_observation,
    )
    return teleop_proc, obs_proc


def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
    """Reset period where human repositions environment."""
    logger.info("[HIL] RESET")

    events["in_reset"] = True
    events["start_next_episode"] = False

    obs = robot.get_observation()
    robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
    teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)

    logger.info("Press any key to enable teleoperation")
    while not events["start_next_episode"] and not events["stop_recording"]:
        precise_sleep(0.05)

    if events["stop_recording"]:
        return

    events["start_next_episode"] = False
    teleop_disable_torque(teleop)
    logger.info("Teleop enabled - press any key to start episode")

    while not events["start_next_episode"] and not events["stop_recording"]:
        loop_start = time.perf_counter()
        action = teleop.get_action()
        robot.send_action(action)
        precise_sleep(1 / fps - (time.perf_counter() - loop_start))

    events["in_reset"] = False
    events["start_next_episode"] = False
    events["exit_early"] = False
    events["policy_paused"] = False
    events["correction_active"] = False
    events["resume_policy"] = False


def print_controls(rtc: bool = False):
    """Print control instructions."""
    mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
    logger.info(
        "%s\n Controls:\n"
        " SPACE - Pause policy\n"
        " c - Take control\n"
        " p - Resume policy after pause/correction\n"
        " → - End episode\n"
        " ESC - Stop and push to hub",
        mode,
    )
@@ -14,17 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.common.control_utils import init_keyboard_listener
import logging
import time

from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

NUM_EPISODES = 2
FPS = 30
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"


def main():
    # NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
    # This script provides a self-contained example for educational purposes.

    # Create the robot configuration & robot
    robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")

@@ -83,43 +90,67 @@ def main():
        raise ValueError("Robot is not connected!")

    print("Starting evaluate loop...")
    control_interval = 1 / FPS
    recorded_episodes = 0
    while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
        log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")

        # Main record loop
        record_loop(
            robot=robot,
            events=events,
            fps=FPS,
            policy=policy,
            preprocessor=preprocessor,  # Pass the pre and post policy processors
            postprocessor=postprocessor,
            dataset=dataset,
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=teleop_action_processor,
            robot_action_processor=robot_action_processor,
            robot_observation_processor=robot_observation_processor,
        )
        # Inline evaluation loop: predict actions and send to robot
        timestamp = 0
        start_episode_t = time.perf_counter()
        while timestamp < EPISODE_TIME_SEC:
            start_loop_t = time.perf_counter()

            if events["exit_early"]:
                events["exit_early"] = False
                break

            # Get robot observation
            obs = robot.get_observation()
            obs_processed = robot_observation_processor(obs)
            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)

            # Predict action using the policy
            action_tensor = predict_action(
                observation=observation_frame,
                policy=policy,
                device=policy.config.device,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                use_amp=policy.config.device.type == "cuda",
                task=TASK_DESCRIPTION,
                robot_type=robot.name,
            )

            # Convert policy output to robot action dict
            action_values = make_robot_action(action_tensor, dataset.features)

            # Process and send action to robot
            robot_action_to_send = robot_action_processor((action_values, obs))
            robot.send_action(robot_action_to_send)

            # Write to dataset
            action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
            frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
            dataset.add_frame(frame)

            log_rerun_data(observation=obs_processed, action=action_values)

            dt_s = time.perf_counter() - start_loop_t
            sleep_time_s = control_interval - dt_s
            if sleep_time_s < 0:
                logging.warning(
                    f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
                )
            precise_sleep(max(sleep_time_s, 0.0))
            timestamp = time.perf_counter() - start_episode_t

        # Reset the environment if not stopping or re-recording
        if not events["stop_recording"] and (
            (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment")
            record_loop(
                robot=robot,
                events=events,
                fps=FPS,
                control_time_s=EPISODE_TIME_SEC,
                single_task=TASK_DESCRIPTION,
                display_data=True,
                teleop_action_processor=teleop_action_processor,
                robot_action_processor=robot_action_processor,
                robot_observation_processor=robot_observation_processor,
            )
            log_say("Waiting for environment reset, press right arrow key when ready...")

        if events["rerecord_episode"]:
            log_say("Re-record episode")
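The tick-timing pattern in the inline evaluation loop above (measure the work time, sleep for whatever remains of the interval, never sleep a negative amount) is a generic fixed-rate control loop. A minimal standalone sketch, using stdlib `time` in place of `precise_sleep`:

```python
import time


def run_fixed_rate(step_fn, fps: float, duration_s: float) -> int:
    """Call step_fn at roughly `fps` Hz for `duration_s` seconds.

    Each tick sleeps only for the remainder of its time slot, so slow
    steps reduce the achieved rate instead of queueing up.
    """
    interval = 1 / fps
    ticks = 0
    start = time.perf_counter()
    while time.perf_counter() - start < duration_s:
        t0 = time.perf_counter()
        step_fn()
        elapsed = time.perf_counter() - t0
        time.sleep(max(interval - elapsed, 0.0))  # clamp: never sleep negative
        ticks += 1
    return ticks
```

If `step_fn` takes longer than one interval, the clamp keeps the loop running flat out, which is where the slower-than-target warning in the script above would fire.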

@@ -45,9 +45,6 @@ def main():
    leader_arm = SO100Leader(leader_arm_config)
    keyboard = KeyboardTeleop(keyboard_config)

    # TODO(Steven): Update this example to use pipelines
    teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()

    # Configure the dataset features
    action_features = hw_to_dataset_features(robot.action_features, ACTION)
    obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
@@ -77,6 +74,10 @@ def main():
    if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
        raise ValueError("Robot or teleop is not connected!")

    teleop_action_processor, robot_action_processor, robot_observation_processor = (
        make_default_processors()
    )

    print("Starting record loop...")
    recorded_episodes = 0
    while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
@@ -87,14 +88,14 @@ def main():
            robot=robot,
            events=events,
            fps=FPS,
            teleop_action_processor=teleop_action_processor,
            robot_action_processor=robot_action_processor,
            robot_observation_processor=robot_observation_processor,
            dataset=dataset,
            teleop=[leader_arm, keyboard],
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=teleop_action_processor,
            robot_action_processor=robot_action_processor,
            robot_observation_processor=robot_observation_processor,
        )

        # Reset the environment if not stopping or re-recording
@@ -106,13 +107,13 @@ def main():
                robot=robot,
                events=events,
                fps=FPS,
                teleop_action_processor=teleop_action_processor,
                robot_action_processor=robot_action_processor,
                robot_observation_processor=robot_observation_processor,
                teleop=[leader_arm, keyboard],
                control_time_s=RESET_TIME_SEC,
                single_task=TASK_DESCRIPTION,
                display_data=True,
                teleop_action_processor=teleop_action_processor,
                robot_action_processor=robot_action_processor,
                robot_observation_processor=robot_observation_processor,
            )

        if events["rerecord_episode"]:

@@ -0,0 +1,77 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run a trained policy on LeKiwi without recording (base rollout).

Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). For a CLI entry point with the same capabilities plus
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
"""

from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging

FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"


def main():
    init_logging()

    # Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
    robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")

    # Policy: load the pretrained config. ``pretrained_path`` is read downstream
    # by ``build_rollout_context`` to reload the full model.
    policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
    policy_config.pretrained_path = HF_MODEL_ID

    # Assemble the rollout config: base strategy (no recording) + sync inference.
    cfg = RolloutConfig(
        robot=robot_config,
        policy=policy_config,
        strategy=BaseStrategyConfig(),
        inference=SyncInferenceConfig(),
        fps=FPS,
        duration=DURATION_SEC,
        task=TASK_DESCRIPTION,
    )

    # Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
    signal_handler = ProcessSignalHandler(use_threads=True)

    # Build the context (connects robot, loads policy, wires the inference strategy).
    # No custom processors here — LeKiwi runs on raw joint features.
    ctx = build_rollout_context(cfg, signal_handler.shutdown_event)

    strategy = BaseStrategy(cfg.strategy)
    try:
        strategy.setup(ctx)
        strategy.run(ctx)
    finally:
        strategy.teardown(ctx)


if __name__ == "__main__":
    main()
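The setup/run/teardown lifecycle used above can also be expressed as a context manager, which guarantees teardown even if setup or run raises. This is a sketch, not part of the lerobot API; `strategy_session` is a hypothetical helper name:

```python
from contextlib import contextmanager


@contextmanager
def strategy_session(strategy, ctx):
    """Mirror the try/finally in the script: setup on entry, teardown on exit."""
    strategy.setup(ctx)
    try:
        yield strategy
    finally:
        strategy.teardown(ctx)
```

Used as `with strategy_session(strategy, ctx) as s: s.run(ctx)`, teardown always runs, matching the explicit try/finally block in the example.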
@@ -0,0 +1,342 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🤗 LeRobot Quickstart\n",
"\n",
"Calibration → teleoperation → data collection → training → evaluation.\n",
"\n",
"Install the required dependencies: `pip install -e .[notebook,dataset,training,viz,hardware]`.\n",
"\n",
"**How to use:**\n",
"1. Edit the **Configuration** cell with your settings.\n",
"2. Run all cells (`Run All`).\n",
"3. Each section prints a ready-to-paste terminal command - copy it and run it.\n",
"\n",
"Each setup is different, please refer to the [LeRobot documentation](https://huggingface.co/docs/lerobot/il_robots) for more details on each step and available options. <br>\n",
"Feel free to make this notebook your own and adapt it to your needs!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _cameras_arg(cameras: dict) -> str:\n",
"    if not cameras:\n",
"        return \"\"\n",
"    entries = [f\"{n}: {{{', '.join(f'{k}: {v}' for k, v in cfg.items())}}}\" for n, cfg in cameras.items()]\n",
"    return \"{ \" + \", \".join(entries) + \" }\"\n",
"\n",
"\n",
"def print_cmd(*parts: str) -> None:\n",
"    \"\"\"Print a shell command with line continuations, skipping empty parts.\"\"\"\n",
"    non_empty = [p for p in parts if p]\n",
"    print(\" \\\\\\n    \".join(non_empty))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Configuration\n",
"\n",
"Edit this cell, then **Run All** to generate all commands below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Robot (follower) - run `lerobot-find-port` to discover the port\n",
"ROBOT_TYPE = \"so101_follower\"\n",
"ROBOT_PORT = \"/dev/ttyACM0\"\n",
"ROBOT_ID = \"my_follower_arm\"\n",
"\n",
"# Teleop (leader) - run `lerobot-find-port` to discover the port\n",
"TELEOP_TYPE = \"so101_leader\"\n",
"TELEOP_PORT = \"/dev/ttyACM1\"\n",
"TELEOP_ID = \"my_leader_arm\"\n",
"\n",
"# Cameras - set to {} to disable\n",
"# Run `lerobot-find-cameras opencv` to list available cameras and their indices\n",
"CAMERAS = {\n",
"    \"top\": {\"type\": \"opencv\", \"index_or_path\": 2, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
"    \"wrist\": {\"type\": \"opencv\", \"index_or_path\": 4, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
"}\n",
"\n",
"# Dataset\n",
"HF_USER = \"your_hf_username\"  # `huggingface-cli whoami` to find your username\n",
"DATASET_NAME = \"my_so101_dataset\"\n",
"TASK_DESCRIPTION = \"pick and place the block\"\n",
"NUM_EPISODES = 10\n",
"\n",
"# Training\n",
"POLICY_TYPE = \"act\"  # act, diffusion, smolvla, ...\n",
"POLICY_DEVICE = \"cuda\"  # cuda / cpu / mps\n",
"TRAIN_STEPS = 10_000\n",
"SAVE_FREQ = 2_000\n",
"OUTPUT_DIR = f\"outputs/train/{DATASET_NAME}\"\n",
"\n",
"# Inference - Hub repo ID or local checkpoint path\n",
"# e.g. set to f\"{OUTPUT_DIR}/checkpoints/last\" to use a local checkpoint\n",
"POLICY_PATH = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
"LAST_CHECKPOINT_PATH = f\"{OUTPUT_DIR}/checkpoints/last\"\n",
"\n",
"# Derived\n",
"DATASET_REPO_ID = f\"{HF_USER}/{DATASET_NAME}\"\n",
"DATASET_ROOT = f\"data/{DATASET_NAME}\"\n",
"POLICY_REPO_ID = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
"EVAL_REPO_ID = f\"{HF_USER}/eval_{DATASET_NAME}\"\n",
"CAMERAS_ARG = _cameras_arg(CAMERAS)\n",
"CAMERAS_FLAG = f'--robot.cameras=\"{CAMERAS_ARG}\"' if CAMERAS_ARG else \"\"\n",
"\n",
"print(f\"Robot  : {ROBOT_TYPE} @ {ROBOT_PORT}\")\n",
"print(f\"Teleop : {TELEOP_TYPE} @ {TELEOP_PORT}\")\n",
"print(f\"Cameras: {list(CAMERAS) or 'none'}\")\n",
"print(f\"Dataset: {DATASET_REPO_ID} ({NUM_EPISODES} episodes) saved to {DATASET_ROOT}\")\n",
"print(f\"Policy : {POLICY_TYPE} -> {POLICY_REPO_ID}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 1. Calibration\n",
"\n",
"Run once per arm before first use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Follower\n",
"print_cmd(\n",
"    \"lerobot-calibrate\",\n",
"    f\"--robot.type={ROBOT_TYPE}\",\n",
"    f\"--robot.port={ROBOT_PORT}\",\n",
"    f\"--robot.id={ROBOT_ID}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Leader\n",
"print_cmd(\n",
"    \"lerobot-calibrate\",\n",
"    f\"--teleop.type={TELEOP_TYPE}\",\n",
"    f\"--teleop.port={TELEOP_PORT}\",\n",
"    f\"--teleop.id={TELEOP_ID}\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 2. Teleoperation\n",
"\n",
"See the [teleoperation docs](https://huggingface.co/docs/lerobot/il_robots#teleoperate) and the [cameras guide](https://huggingface.co/docs/lerobot/cameras) for more options."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
"    \"lerobot-teleoperate\",\n",
"    f\"--robot.type={ROBOT_TYPE}\",\n",
"    f\"--robot.port={ROBOT_PORT}\",\n",
"    f\"--robot.id={ROBOT_ID}\",\n",
"    CAMERAS_FLAG,\n",
"    f\"--teleop.type={TELEOP_TYPE}\",\n",
"    f\"--teleop.port={TELEOP_PORT}\",\n",
"    f\"--teleop.id={TELEOP_ID}\",\n",
"    \"--display_data=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 3. Record Dataset\n",
"\n",
"See the [recording docs](https://huggingface.co/docs/lerobot/il_robots#record-a-dataset) for tips on gathering good data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
"    \"lerobot-record\",\n",
"    f\"--robot.type={ROBOT_TYPE}\",\n",
"    f\"--robot.port={ROBOT_PORT}\",\n",
"    f\"--robot.id={ROBOT_ID}\",\n",
"    CAMERAS_FLAG,\n",
"    f\"--teleop.type={TELEOP_TYPE}\",\n",
"    f\"--teleop.port={TELEOP_PORT}\",\n",
"    f\"--teleop.id={TELEOP_ID}\",\n",
"    f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
"    f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
"    f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
"    \"--dataset.streaming_encoding=true\",\n",
"    \"--display_data=true\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Resume a previously interrupted recording session\n",
"print_cmd(\n",
"    \"lerobot-record\",\n",
"    f\"--robot.type={ROBOT_TYPE}\",\n",
"    f\"--robot.port={ROBOT_PORT}\",\n",
"    f\"--robot.id={ROBOT_ID}\",\n",
"    CAMERAS_FLAG,\n",
"    f\"--teleop.type={TELEOP_TYPE}\",\n",
"    f\"--teleop.port={TELEOP_PORT}\",\n",
"    f\"--teleop.id={TELEOP_ID}\",\n",
"    f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
"    f\"--dataset.root={DATASET_ROOT}\",\n",
"    f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
"    f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
"    \"--dataset.streaming_encoding=true\",\n",
"    \"--display_data=true\",\n",
"    \"--resume=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 4. Train Policy\n",
"\n",
"See the [training docs](https://huggingface.co/docs/lerobot/il_robots#train-a-policy) for configuration options and tips."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
"    \"lerobot-train\",\n",
"    f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
"    f\"--policy.type={POLICY_TYPE}\",\n",
"    f\"--policy.device={POLICY_DEVICE}\",\n",
"    f\"--policy.repo_id={POLICY_REPO_ID}\",\n",
"    f\"--output_dir={OUTPUT_DIR}\",\n",
"    f\"--steps={TRAIN_STEPS}\",\n",
"    f\"--save_freq={SAVE_FREQ}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Resume a previously interrupted training session\n",
"print_cmd(\n",
"    \"lerobot-train\",\n",
"    f\"--config_path={LAST_CHECKPOINT_PATH}/pretrained_model/train_config.json\",\n",
"    \"--resume=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 5. Inference\n",
"\n",
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also set it to `LAST_CHECKPOINT_PATH`.\n",
"\n",
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
"    \"lerobot-record\",\n",
"    f\"--policy.path={POLICY_PATH}\",\n",
"    f\"--robot.type={ROBOT_TYPE}\",\n",
"    f\"--robot.port={ROBOT_PORT}\",\n",
"    f\"--robot.id={ROBOT_ID}\",\n",
"    CAMERAS_FLAG,\n",
"    f\"--teleop.type={TELEOP_TYPE}\",\n",
"    f\"--teleop.port={TELEOP_PORT}\",\n",
"    f\"--teleop.id={TELEOP_ID}\",\n",
"    f\"--dataset.repo_id={EVAL_REPO_ID}\",\n",
"    f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
"    f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
"    \"--dataset.streaming_encoding=true\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot (3.12.3)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
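The notebook's `print_cmd` helper joins the non-empty flag strings with shell line continuations. A string-returning variant makes the behaviour easy to check (`format_cmd` is a hypothetical name, not part of the notebook; the continuation indent of four spaces is an assumption):

```python
def format_cmd(*parts: str) -> str:
    """Join non-empty parts with ' \\' line continuations, as print_cmd prints them."""
    return " \\\n    ".join(p for p in parts if p)


# Empty parts (e.g. a disabled CAMERAS_FLAG) are dropped from the command.
cmd = format_cmd(
    "lerobot-calibrate",
    "--robot.type=so101_follower",
    "",
    "--robot.id=my_follower_arm",
)
```

Each line except the last ends with a trailing backslash, so the printed command can be pasted straight into a shell.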
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time

from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
    RobotProcessorPipeline,
    make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
    ForwardKinematicsJointsToEE,
    InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

NUM_EPISODES = 5
FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"


def main():
    # NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
    # This script provides a self-contained example for educational purposes.

    # Create the robot configuration & robot
    camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
    robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
        raise ValueError("Robot is not connected!")

    print("Starting evaluate loop...")
    control_interval = 1 / FPS
    episode_idx = 0
    for episode_idx in range(NUM_EPISODES):
        log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

        # Main record loop
        record_loop(
            robot=robot,
            events=events,
            fps=FPS,
            policy=policy,
            preprocessor=preprocessor,  # Pass the pre and post policy processors
            postprocessor=postprocessor,
            dataset=dataset,
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=make_default_teleop_action_processor(),
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose_processor,
        )
        # Inline evaluation loop: predict actions and send to robot
        timestamp = 0
        start_episode_t = time.perf_counter()
        while timestamp < EPISODE_TIME_SEC:
            start_loop_t = time.perf_counter()

            if events["exit_early"]:
                events["exit_early"] = False
                break

            # Get robot observation
            obs = robot.get_observation()
            obs_processed = robot_joints_to_ee_pose_processor(obs)
            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)

            # Predict action using the policy
            action_tensor = predict_action(
                observation=observation_frame,
                policy=policy,
                device=policy.config.device,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                use_amp=policy.config.device.type == "cuda",
                task=TASK_DESCRIPTION,
                robot_type=robot.name,
            )

            # Convert policy output to robot action dict
            action_values = make_robot_action(action_tensor, dataset.features)

            # Process and send action to robot (EE -> joints via IK)
            robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
            robot.send_action(robot_action_to_send)

            # Write to dataset
            action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
            frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
            dataset.add_frame(frame)

            log_rerun_data(observation=obs_processed, action=action_values)

            dt_s = time.perf_counter() - start_loop_t
            sleep_time_s = control_interval - dt_s
            if sleep_time_s < 0:
                logging.warning(
                    f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
                )
            precise_sleep(max(sleep_time_s, 0.0))
            timestamp = time.perf_counter() - start_episode_t

        # Reset the environment if not stopping or re-recording
        if not events["stop_recording"] and (
            (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment")
            record_loop(
                robot=robot,
                events=events,
                fps=FPS,
                control_time_s=EPISODE_TIME_SEC,
                single_task=TASK_DESCRIPTION,
                display_data=True,
                teleop_action_processor=make_default_teleop_action_processor(),
                robot_action_processor=robot_ee_to_joints_processor,
                robot_observation_processor=robot_joints_to_ee_pose_processor,
            )
        log_say("Waiting for environment reset, press right arrow key when ready...")

        if events["rerecord_episode"]:
            log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():

        # Save episode
        dataset.save_episode()
        episode_idx += 1
    finally:
        # Clean up
        log_say("Stop recording")

@@ -65,14 +65,15 @@ def main():
    robot = SO100Follower(robot_config)
    phone = Phone(teleop_config)

    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
    # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    kinematics_solver = RobotKinematics(
        urdf_path="./SO101/so101_new_calib.urdf",
        target_frame_name="gripper_frame_link",
        joint_names=list(robot.bus.motors.keys()),
    )

    # Build pipeline to convert phone action to EE action
    # Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
    phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
        tuple[RobotAction, RobotObservation], RobotAction
    ](
@@ -94,7 +95,7 @@ def main():
        to_output=transition_to_robot_action,
    )

    # Build pipeline to convert EE action to joints action
    # Build pipeline to convert EE action to joints action (IK).
    robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        steps=[
            InverseKinematicsEEToJoints(
@@ -107,7 +108,7 @@ def main():
        to_output=transition_to_robot_action,
    )

    # Build pipeline to convert joint observation to EE observation
    # Build pipeline to convert joint observation to EE observation (FK).
    robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
        steps=[
            ForwardKinematicsJointsToEE(
@@ -118,13 +119,12 @@ def main():
        to_output=transition_to_observation,
    )

    # Create the dataset
    # Create the dataset, deriving features from the pipelines so the on-disk schema
    # matches exactly what the pipelines produce at runtime.
    dataset = LeRobotDataset.create(
        repo_id=HF_REPO_ID,
        fps=FPS,
        features=combine_feature_dicts(
            # Run the feature contract of the pipelines
            # This tells you how the features would look like after the pipeline steps
            aggregate_pipeline_dataset_features(
                pipeline=phone_to_robot_ee_pose_processor,
                initial_features=create_initial_features(action=phone.action_features),
@@ -163,14 +163,14 @@ def main():
            robot=robot,
            events=events,
            fps=FPS,
            teleop_action_processor=phone_to_robot_ee_pose_processor,
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose,
            teleop=phone,
            dataset=dataset,
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=phone_to_robot_ee_pose_processor,
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose,
        )

        # Reset the environment if not stopping or re-recording
@@ -182,13 +182,13 @@ def main():
            robot=robot,
            events=events,
            fps=FPS,
            teleop_action_processor=phone_to_robot_ee_pose_processor,
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose,
            teleop=phone,
            control_time_s=RESET_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=phone_to_robot_ee_pose_processor,
            robot_action_processor=robot_ee_to_joints_processor,
            robot_observation_processor=robot_joints_to_ee_pose,
        )

        if events["rerecord_episode"]:

@@ -0,0 +1,126 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run a trained EE-space policy on SO100 (phone-trained) without recording.

Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
with phone teleoperation in EE space, so at deployment we only need the
joint↔EE conversion on the robot side; the phone is not used.

Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
(inline policy call). For recording during rollout, switch to Sentry,
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
"""

from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
    RobotProcessorPipeline,
    observation_to_transition,
    robot_action_observation_to_transition,
    transition_to_observation,
    transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
    ForwardKinematicsJointsToEE,
    InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging

FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"


def main():
    init_logging()

    camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
    robot_config = SO100FollowerConfig(
        port="/dev/tty.usbmodem58760434471",
        id="my_awesome_follower_arm",
        cameras=camera_config,
        use_degrees=True,
    )

    # Peek at motor names once to build the kinematic solver.
    temp_robot = SO100Follower(robot_config)
    motor_names = list(temp_robot.bus.motors.keys())

    kinematics_solver = RobotKinematics(
        urdf_path="./SO101/so101_new_calib.urdf",
        target_frame_name="gripper_frame_link",
        joint_names=motor_names,
    )

    robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
        steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
        to_transition=observation_to_transition,
        to_output=transition_to_observation,
    )

    robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        steps=[
            InverseKinematicsEEToJoints(
                kinematics=kinematics_solver,
                motor_names=motor_names,
                initial_guess_current_joints=True,
            ),
        ],
        to_transition=robot_action_observation_to_transition,
        to_output=transition_to_robot_action,
    )

    policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
    policy_config.pretrained_path = HF_MODEL_ID

    cfg = RolloutConfig(
        robot=robot_config,
        policy=policy_config,
        strategy=BaseStrategyConfig(),
        inference=SyncInferenceConfig(),
        fps=FPS,
        duration=DURATION_SEC,
        task=TASK_DESCRIPTION,
    )

    signal_handler = ProcessSignalHandler(use_threads=True)

    ctx = build_rollout_context(
        cfg,
        signal_handler.shutdown_event,
        robot_action_processor=robot_ee_to_joints_processor,
        robot_observation_processor=robot_joints_to_ee_pose_processor,
    )

    strategy = BaseStrategy(cfg.strategy)
    try:
        strategy.setup(ctx)
        strategy.run(ctx)
    finally:
        strategy.teardown(ctx)


if __name__ == "__main__":
    main()
@@ -1,673 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.

This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation

For simulation environments, see eval_with_simulation.py

Usage:
    # Run RTC with Real robot with RTC
    uv run examples/rtc/eval_with_real_robot.py \
        --policy.path=<USER>/smolvla_check_rtc_last3 \
        --policy.device=mps \
        --rtc.enabled=true \
        --rtc.execution_horizon=20 \
        --robot.type=so100_follower \
        --robot.port=/dev/tty.usbmodem58FA0834591 \
        --robot.id=so100_follower \
        --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
        --task="Move green small object into the purple platform" \
        --duration=120

    # Run RTC with Real robot without RTC
    uv run examples/rtc/eval_with_real_robot.py \
        --policy.path=<USER>/smolvla_check_rtc_last3 \
        --policy.device=mps \
        --rtc.enabled=false \
        --robot.type=so100_follower \
        --robot.port=/dev/tty.usbmodem58FA0834591 \
        --robot.id=so100_follower \
        --robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
        --task="Move green small object into the purple platform" \
        --duration=120

    # Run RTC with Real robot with pi0.5 policy
    uv run examples/rtc/eval_with_real_robot.py \
        --policy.path=<USER>/pi05_check_rtc \
        --policy.device=mps \
        --rtc.enabled=true \
        --rtc.execution_horizon=20 \
        --robot.type=so100_follower \
        --robot.port=/dev/tty.usbmodem58FA0834591 \
        --robot.id=so100_follower \
        --robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
        --task="Move green small object into the purple platform" \
        --duration=120

    # Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
    python examples/rtc/eval_with_real_robot.py \
        --policy.path=lerobot-data-collection/folding_final \
        --robot.type=bi_openarm_follower \
        --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
        --robot.left_arm_config.port=can0 \
        --robot.left_arm_config.side=left \
        --robot.left_arm_config.can_interface=socketcan \
        --robot.left_arm_config.disable_torque_on_disconnect=true \
        --robot.left_arm_config.max_relative_target=8.0 \
        --robot.right_arm_config.port=can1 \
        --robot.right_arm_config.side=right \
        --robot.right_arm_config.can_interface=socketcan \
        --robot.right_arm_config.disable_torque_on_disconnect=true \
        --robot.right_arm_config.max_relative_target=8.0 \
        --task="Fold the T-shirt properly" \
        --fps=30 \
        --duration=2000 \
        --interpolation_multiplier=3 \
        --rtc.enabled=true \
        --rtc.execution_horizon=20 \
        --rtc.max_guidance_weight=5.0 \
        --rtc.prefix_attention_schedule=LINEAR \
        --device=cuda
"""

import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread

import torch
from torch import Tensor

from lerobot.cameras.opencv import OpenCVCameraConfig  # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig  # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig  # noqa: F401
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
    NormalizerProcessorStep,
    RelativeActionsProcessorStep,
    TransitionKey,
    create_transition,
    make_default_robot_action_processor,
    make_default_robot_observation_processor,
    to_relative_actions,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
    bi_openarm_follower,
    bi_so_follower,
    koch_follower,
    so_follower,
    unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RobotWrapper:
    def __init__(self, robot: Robot):
        self.robot = robot
        self.lock = Lock()

    def get_observation(self) -> dict[str, Tensor]:
        with self.lock:
            return self.robot.get_observation()

    def send_action(self, action: Tensor):
        with self.lock:
            self.robot.send_action(action)

    def observation_features(self) -> list[str]:
        with self.lock:
            return self.robot.observation_features

    def action_features(self) -> list[str]:
        with self.lock:
            return self.robot.action_features


@dataclass
class RTCDemoConfig(HubMixin):
    """Configuration for RTC demo with action chunking policies and real robots."""

    # Policy configuration
    policy: PreTrainedConfig | None = None

    # Robot configuration
    robot: RobotConfig | None = None

    # RTC configuration
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            execution_horizon=10,
            max_guidance_weight=1.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
        )
    )

    # Demo parameters
    duration: float = 30.0  # Duration to run the demo (seconds)
    fps: float = 10.0  # Action execution frequency (Hz)
    interpolation_multiplier: int = 1  # Control rate multiplier (1=off, 2=2x, 3=3x)

    # Compute device
    device: str | None = None  # Device to run on (cuda, cpu, auto)

    # Get new actions horizon. The amount of executed steps after which will be requested new actions.
    # It should be higher than inference delay + execution horizon.
    action_queue_size_to_get_new_actions: int = 30

    # Task to execute
    task: str = field(default="", metadata={"help": "Task to execute"})

    # Torch compile configuration
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )

    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )

    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    torch_compile_disable_cudagraphs: bool = field(
        default=True,
        metadata={
            "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
            "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
        },
    )

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        else:
            raise ValueError("Policy path is required")

        # Validate that robot configuration is provided
        if self.robot is None:
            raise ValueError("Robot configuration must be provided")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]


def is_image_key(k: str) -> bool:
    return k.startswith(OBS_IMAGES)


def _reanchor_relative_rtc_prefix(
    prev_actions_absolute: Tensor,
    current_state: Tensor,
    relative_step: RelativeActionsProcessorStep,
    normalizer_step: NormalizerProcessorStep | None,
    policy_device: torch.device | str,
) -> Tensor:
    """Convert absolute leftovers into model-space for relative-action RTC policies.

    When a policy uses relative actions, the RTC prefix (leftover actions from
    the previous chunk) is stored in absolute space. Before feeding it back to
    the policy we need to re-express it relative to the *current* robot state
    and then re-normalize.
    """
    state = current_state.detach().cpu()
    if state.dim() == 1:
        state = state.unsqueeze(0)

    action_cpu = prev_actions_absolute.detach().cpu()
    mask = relative_step._build_mask(action_cpu.shape[-1])
    relative_actions = to_relative_actions(action_cpu, state, mask)

    transition = create_transition(action=relative_actions)
    if normalizer_step is not None:
        transition = normalizer_step(transition)

    return transition[TransitionKey.ACTION].to(policy_device)


def get_actions(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to request action chunks from the policy.

    Args:
        policy: The policy instance (SmolVLA, Pi0, etc.)
        robot: The robot instance for getting observations
        robot_observation_processor: Processor for raw robot observations
        action_queue: Queue to put new action chunks
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[GET_ACTIONS] Starting get actions thread")

        latency_tracker = LatencyTracker()  # Track latency of action chunks
        fps = cfg.fps
        time_per_chunk = 1.0 / fps

        # Only keep .pos joints + camera streams if the policy was trained on positions,
        # not the full pos/vel/torque state the robot exposes.
        observation_features_hw = {
            key: value
            for key, value in robot.observation_features().items()
            if key.endswith(".pos") or isinstance(value, tuple)
        }

        dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
        policy_device = policy.config.device

        # Load preprocessor and postprocessor from pretrained files
        # The stats are embedded in the processor .safetensors files
        logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=None,  # Will load from pretrained processor files
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
            },
        )

        logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")

        relative_step = next(
            (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
            None,
        )
        normalizer_step = next(
            (s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
            None,
        )
        if relative_step is not None:
            if relative_step.action_names is None:
                cfg_names = getattr(cfg.policy, "action_feature_names", None)
                if cfg_names:
                    relative_step.action_names = list(cfg_names)
                else:
                    relative_step.action_names = [
                        k for k in robot.robot.action_features if k.endswith(".pos")
                    ]
            logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")

        get_actions_threshold = cfg.action_queue_size_to_get_new_actions

        if not cfg.rtc.enabled:
            get_actions_threshold = 0

        while not shutdown_event.is_set():
            if action_queue.qsize() <= get_actions_threshold:
                current_time = time.perf_counter()
                action_index_before_inference = action_queue.get_action_index()
                prev_actions = action_queue.get_left_over()

                inference_latency = latency_tracker.max()
                inference_delay = math.ceil(inference_latency / time_per_chunk)

                obs = robot.get_observation()

                # Apply robot observation processor
                obs_processed = robot_observation_processor(obs)

                obs_with_policy_features = build_dataset_frame(
                    dataset_features, obs_processed, prefix="observation"
                )

                for name in obs_with_policy_features:
                    obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
                    if "image" in name:
                        obs_with_policy_features[name] = (
                            obs_with_policy_features[name].type(torch.float32) / 255
                        )
                        obs_with_policy_features[name] = (
                            obs_with_policy_features[name].permute(2, 0, 1).contiguous()
                        )
                    obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
                    obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)

                obs_with_policy_features["task"] = [cfg.task]  # Task should be a list, not a string!
                obs_with_policy_features["robot_type"] = (
                    robot.robot.name if hasattr(robot.robot, "name") else ""
                )

                preproceseded_obs = preprocessor(obs_with_policy_features)

                # Re-anchor leftover actions for relative-action policies.
                # We need the *postprocessed* (absolute) leftover, not the original
                # (normalized/relative) one that get_left_over() returns.
                if (
                    prev_actions is not None
                    and relative_step is not None
                    and OBS_STATE in obs_with_policy_features
                ):
                    with action_queue.lock:
                        if action_queue.queue is not None:
                            prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
                        else:
                            prev_actions_abs = None
                    if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
                        prev_actions = _reanchor_relative_rtc_prefix(
                            prev_actions_absolute=prev_actions_abs,
                            current_state=obs_with_policy_features[OBS_STATE],
                            relative_step=relative_step,
                            normalizer_step=normalizer_step,
                            policy_device=policy_device,
                        )

                # Generate actions WITH RTC
                actions = policy.predict_action_chunk(
                    preproceseded_obs,
                    inference_delay=inference_delay,
                    prev_chunk_left_over=prev_actions,
                )

                # Store original actions (before postprocessing) for RTC
                original_actions = actions.squeeze(0).clone()

                postprocessed_actions = postprocessor(actions)

                postprocessed_actions = postprocessed_actions.squeeze(0)

                new_latency = time.perf_counter() - current_time
                new_delay = math.ceil(new_latency / time_per_chunk)
                latency_tracker.add(new_latency)

                if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                    logger.warning(
                        "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
                    )

                action_queue.merge(
                    original_actions, postprocessed_actions, new_delay, action_index_before_inference
                )
            else:
                # Small sleep to prevent busy waiting
                time.sleep(0.1)

        logger.info("[GET_ACTIONS] get actions thread shutting down")
    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


def actor_control(
    robot: RobotWrapper,
    robot_action_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to execute actions on the robot.

    Args:
        robot: The robot instance
        action_queue: Queue to get actions from
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[ACTOR] Starting actor thread")

        action_keys = [k for k in robot.action_features() if k.endswith(".pos")]

        action_count = 0
        interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
        action_interval = interpolator.get_control_interval(cfg.fps)

        while not shutdown_event.is_set():
            start_time = time.perf_counter()

            if interpolator.needs_new_action():
                new_action = action_queue.get()
                if new_action is not None:
                    interpolator.add(new_action.cpu())

            action = interpolator.get()
            if action is not None:
                action = action.cpu()
                action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
                action_processed = robot_action_processor((action_dict, None))
                robot.send_action(action_processed)
                action_count += 1

            dt_s = time.perf_counter() - start_time
            time.sleep(max(0, (action_interval - dt_s) - 0.001))

        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
    except Exception as e:
        logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


def _apply_torch_compile(policy, cfg: RTCDemoConfig):
    """Apply torch.compile to the policy's predict_action_chunk method.

    Args:
        policy: Policy instance to compile
        cfg: Configuration containing torch compile settings

    Returns:
        Policy with compiled predict_action_chunk method
    """

    # PI models handle their own compilation
    if policy.type == "pi05" or policy.type == "pi0":
        return policy

    try:
        # Check if torch.compile is available (PyTorch 2.0+)
        if not hasattr(torch, "compile"):
            logger.warning(
                "torch.compile is not available. Requires PyTorch 2.0+. "
                f"Current version: {torch.__version__}. Skipping compilation."
            )
            return policy

        logger.info("Applying torch.compile to predict_action_chunk...")
        logger.info(f" Backend: {cfg.torch_compile_backend}")
        logger.info(f" Mode: {cfg.torch_compile_mode}")
        logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")

        # Compile the predict_action_chunk method
        # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
        compile_kwargs = {
            "backend": cfg.torch_compile_backend,
            "mode": cfg.torch_compile_mode,
        }

        # Disable CUDA graphs if requested (prevents tensor aliasing issues)
        if cfg.torch_compile_disable_cudagraphs:
            compile_kwargs["options"] = {"triton.cudagraphs": False}

        original_method = policy.predict_action_chunk
        compiled_method = torch.compile(original_method, **compile_kwargs)
        policy.predict_action_chunk = compiled_method
        logger.info("✓ Successfully compiled predict_action_chunk")

    except Exception as e:
        logger.error(f"Failed to apply torch.compile: {e}")
        logger.warning("Continuing without torch.compile")

    return policy

@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
    """Main entry point for RTC demo with draccus configuration."""

    # Initialize logging
    init_logging()

    logger.info(f"Using device: {cfg.device}")

    # Setup signal handler for graceful shutdown
    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event

    policy = None
    robot = None
    get_actions_thread = None
    actor_thread = None

    policy_class = get_policy_class(cfg.policy.type)

    # Load config and set compile_model for pi0/pi05 models
    config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)

    if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
        config.compile_model = cfg.use_torch_compile

    if config.use_peft:
        from peft import PeftConfig, PeftModel

        peft_pretrained_path = cfg.policy.pretrained_path
        peft_config = PeftConfig.from_pretrained(peft_pretrained_path)

        policy = policy_class.from_pretrained(
            pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
        )
        policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
    else:
        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

    # Turn on RTC
    policy.config.rtc_config = cfg.rtc

    # Initialize the RTC processor explicitly: by default it is not created
    # when RTC is disabled in the config.
    policy.init_rtc_processor()

    assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"

    policy = policy.to(cfg.device)
    policy.eval()

    # Apply torch.compile to predict_action_chunk method if enabled
    if cfg.use_torch_compile:
        policy = _apply_torch_compile(policy, cfg)

    # Create robot
    logger.info(f"Initializing robot: {cfg.robot.type}")
    robot = make_robot_from_config(cfg.robot)
    robot.connect()
    robot_wrapper = RobotWrapper(robot)

    # Create robot observation processor
    robot_observation_processor = make_default_robot_observation_processor()
    robot_action_processor = make_default_robot_action_processor()

    # Create action queue for communication between threads
    action_queue = ActionQueue(cfg.rtc)

    # Start chunk requester thread
    get_actions_thread = Thread(
        target=get_actions,
        args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
        daemon=True,
        name="GetActions",
    )
    get_actions_thread.start()
    logger.info("Started get actions thread")

    # Start action executor thread
    actor_thread = Thread(
        target=actor_control,
        args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
        daemon=True,
        name="Actor",
    )
    actor_thread.start()
    logger.info("Started actor thread")

    # Main thread monitors for duration or shutdown
    logger.info(f"Running demo for {cfg.duration} seconds...")
    start_time = time.time()

    while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
        time.sleep(10)

        # Log queue status periodically
        if int(time.time() - start_time) % 5 == 0:
            logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")

        if time.time() - start_time > cfg.duration:
            break

    logger.info("Demo duration reached or shutdown requested")

    # Signal shutdown
    shutdown_event.set()

    # Wait for threads to finish
    if get_actions_thread and get_actions_thread.is_alive():
        logger.info("Waiting for chunk requester thread to finish...")
        get_actions_thread.join()

    if actor_thread and actor_thread.is_alive():
        logger.info("Waiting for action executor thread to finish...")
        actor_thread.join()

    # Cleanup robot
    if robot:
        robot.disconnect()
        logger.info("Robot disconnected")

    logger.info("Cleanup completed")


if __name__ == "__main__":
    demo_cli()
    logging.info("RTC demo finished")
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time

from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
    RobotProcessorPipeline,
    make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
    ForwardKinematicsJointsToEE,
    InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

NUM_EPISODES = 5
FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"


def main():
    # NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
    # This script provides a self-contained example for educational purposes.

    # Create the robot configuration & robot
    camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
    robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
            raise ValueError("Robot is not connected!")

        print("Starting evaluate loop...")
        control_interval = 1 / FPS
        episode_idx = 0
        for episode_idx in range(NUM_EPISODES):
            log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")

            # Main record loop
            record_loop(
                robot=robot,
                events=events,
                fps=FPS,
                policy=policy,
                preprocessor=preprocessor, # Pass the pre and post policy processors
                postprocessor=postprocessor,
                dataset=dataset,
                control_time_s=EPISODE_TIME_SEC,
                single_task=TASK_DESCRIPTION,
                display_data=True,
                teleop_action_processor=make_default_teleop_action_processor(),
                robot_action_processor=robot_ee_to_joints_processor,
                robot_observation_processor=robot_joints_to_ee_pose_processor,
            )
            # Inline evaluation loop: predict actions and send to robot
            timestamp = 0
            start_episode_t = time.perf_counter()
            while timestamp < EPISODE_TIME_SEC:
                start_loop_t = time.perf_counter()

                if events["exit_early"]:
                    events["exit_early"] = False
                    break

                # Get robot observation
                obs = robot.get_observation()
                obs_processed = robot_joints_to_ee_pose_processor(obs)
                observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)

                # Predict action using the policy
                action_tensor = predict_action(
                    observation=observation_frame,
                    policy=policy,
                    device=policy.config.device,
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                    use_amp=policy.config.device.type == "cuda",
                    task=TASK_DESCRIPTION,
                    robot_type=robot.name,
                )

                # Convert policy output to robot action dict
                action_values = make_robot_action(action_tensor, dataset.features)

                # Process and send action to robot (EE -> joints via IK)
                robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
                robot.send_action(robot_action_to_send)

                # Write to dataset
                action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
                frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
                dataset.add_frame(frame)

                log_rerun_data(observation=obs_processed, action=action_values)

                dt_s = time.perf_counter() - start_loop_t
                sleep_time_s = control_interval - dt_s
                if sleep_time_s < 0:
                    logging.warning(
                        f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
                    )
                precise_sleep(max(sleep_time_s, 0.0))
                timestamp = time.perf_counter() - start_episode_t

            # Reset the environment if not stopping or re-recording
            if not events["stop_recording"] and (
                (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
            ):
                log_say("Reset the environment")
                record_loop(
                    robot=robot,
                    events=events,
                    fps=FPS,
                    control_time_s=EPISODE_TIME_SEC,
                    single_task=TASK_DESCRIPTION,
                    display_data=True,
                    teleop_action_processor=make_default_teleop_action_processor(),
                    robot_action_processor=robot_ee_to_joints_processor,
                    robot_observation_processor=robot_joints_to_ee_pose_processor,
                )
                log_say("Waiting for environment reset, press right arrow key when ready...")

            if events["rerecord_episode"]:
                log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():

            # Save episode
            dataset.save_episode()
            episode_idx += 1
    finally:
        # Clean up
        log_say("Stop recording")

@@ -62,21 +62,20 @@ def main():
    follower = SO100Follower(follower_config)
    leader = SO100Leader(leader_config)

    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
    # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    follower_kinematics_solver = RobotKinematics(
        urdf_path="./SO101/so101_new_calib.urdf",
        target_frame_name="gripper_frame_link",
        joint_names=list(follower.bus.motors.keys()),
    )

    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    leader_kinematics_solver = RobotKinematics(
        urdf_path="./SO101/so101_new_calib.urdf",
        target_frame_name="gripper_frame_link",
        joint_names=list(leader.bus.motors.keys()),
    )

    # Build pipeline to convert follower joints to EE observation
    # Build pipeline to convert follower joints to EE observation.
    follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
        steps=[
            ForwardKinematicsJointsToEE(
@@ -87,7 +86,7 @@ def main():
        to_output=transition_to_observation,
    )

    # Build pipeline to convert leader joints to EE action
    # Build pipeline to convert leader joints to EE action.
    leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        steps=[
            ForwardKinematicsJointsToEE(
@@ -98,9 +97,9 @@ def main():
        to_output=transition_to_robot_action,
    )

    # Build pipeline to convert EE action to follower joints
    # Build pipeline to convert EE action to follower joints (with safety bounds).
    ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        [
        steps=[
            EEBoundsAndSafety(
                end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
                max_ee_step_m=0.10,
@@ -115,13 +114,12 @@ def main():
        to_output=transition_to_robot_action,
    )

    # Create the dataset
    # Create the dataset, deriving features from the pipelines so the on-disk schema
    # matches exactly what the pipelines produce at runtime.
    dataset = LeRobotDataset.create(
        repo_id=HF_REPO_ID,
        fps=FPS,
        features=combine_feature_dicts(
            # Run the feature contract of the pipelines
            # This tells you how the features would look like after the pipeline steps
            aggregate_pipeline_dataset_features(
                pipeline=leader_joints_to_ee,
                initial_features=create_initial_features(action=leader.action_features),
@@ -144,7 +142,7 @@ def main():

    # Initialize the keyboard listener and rerun visualization
    listener, events = init_keyboard_listener()
    init_rerun(session_name="recording_phone")
    init_rerun(session_name="recording_so100_ee")

    try:
        if not leader.is_connected or not follower.is_connected:
@@ -160,14 +158,14 @@ def main():
            robot=follower,
            events=events,
            fps=FPS,
            teleop_action_processor=leader_joints_to_ee,
            robot_action_processor=ee_to_follower_joints,
            robot_observation_processor=follower_joints_to_ee,
            teleop=leader,
            dataset=dataset,
            control_time_s=EPISODE_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=leader_joints_to_ee,
            robot_action_processor=ee_to_follower_joints,
            robot_observation_processor=follower_joints_to_ee,
        )

        # Reset the environment if not stopping or re-recording
@@ -179,13 +177,13 @@ def main():
            robot=follower,
            events=events,
            fps=FPS,
            teleop_action_processor=leader_joints_to_ee,
            robot_action_processor=ee_to_follower_joints,
            robot_observation_processor=follower_joints_to_ee,
            teleop=leader,
            control_time_s=RESET_TIME_SEC,
            single_task=TASK_DESCRIPTION,
            display_data=True,
            teleop_action_processor=leader_joints_to_ee,
            robot_action_processor=ee_to_follower_joints,
            robot_observation_processor=follower_joints_to_ee,
        )

        if events["rerecord_episode"]:

@@ -0,0 +1,134 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run a trained EE-space policy on SO100 without recording (base rollout).

Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). The custom observation/action processors convert between
joint space (robot hardware) and end-effector space (policy I/O) via
forward/inverse kinematics.
"""

from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
    RobotProcessorPipeline,
    observation_to_transition,
    robot_action_observation_to_transition,
    transition_to_observation,
    transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
    ForwardKinematicsJointsToEE,
    InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging

FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"


def main():
    init_logging()

    # Robot configuration — the rollout engine will connect it inside build_rollout_context.
    camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
    robot_config = SO100FollowerConfig(
        port="/dev/tty.usbmodem5A460814411",
        id="my_awesome_follower_arm",
        cameras=camera_config,
        use_degrees=True,
    )

    # Kinematic solver: we need the motor-name list, so peek at the robot once.
    # (The rollout engine owns the connected instance; we only use this for introspection.)
    temp_robot = SO100Follower(robot_config)
    motor_names = list(temp_robot.bus.motors.keys())

    # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
    # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
    kinematics_solver = RobotKinematics(
        urdf_path="./SO101/so101_new_calib.urdf",
        target_frame_name="gripper_frame_link",
        joint_names=motor_names,
    )

    # Joint-space observation → EE-space observation (consumed by the policy).
    robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
        steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
        to_transition=observation_to_transition,
        to_output=transition_to_observation,
    )

    # EE-space action (produced by the policy) → joint-space action (sent to robot).
    robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
        steps=[
            InverseKinematicsEEToJoints(
                kinematics=kinematics_solver,
                motor_names=motor_names,
                initial_guess_current_joints=True,
            ),
        ],
        to_transition=robot_action_observation_to_transition,
        to_output=transition_to_robot_action,
    )

    # Policy config (full model is loaded inside build_rollout_context).
    policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
    policy_config.pretrained_path = HF_MODEL_ID

    cfg = RolloutConfig(
        robot=robot_config,
        policy=policy_config,
        strategy=BaseStrategyConfig(),
        inference=SyncInferenceConfig(),
        fps=FPS,
        duration=DURATION_SEC,
        task=TASK_DESCRIPTION,
    )

    signal_handler = ProcessSignalHandler(use_threads=True)

    # Pass the EE kinematic processors via kwargs; the defaults (identity) would
    # otherwise skip the joint↔EE conversion and the policy would receive the
    # wrong observation/action space.
    ctx = build_rollout_context(
        cfg,
        signal_handler.shutdown_event,
        robot_action_processor=robot_ee_to_joints_processor,
        robot_observation_processor=robot_joints_to_ee_pose_processor,
    )

    strategy = BaseStrategy(cfg.strategy)
    try:
        strategy.setup(ctx)
        strategy.run(ctx)
    finally:
        strategy.teardown(ctx)


if __name__ == "__main__":
    main()
@@ -108,9 +108,9 @@ training = [
    "wandb>=0.24.0,<0.25.0",
]
hardware = [
    "pynput>=1.7.8,<1.9.0",
    "pyserial>=3.5,<4.0",
    "deepdiff>=7.0.1,<9.0.0",
    "lerobot[pynput-dep]",
    "lerobot[pyserial-dep]",
    "lerobot[deepdiff-dep]",
]
viz = [
    "rerun-sdk>=0.24.0,<0.27.0",
@@ -136,10 +136,14 @@ scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]

# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
damiao = ["lerobot[can-dep]"]
robstride = ["lerobot[can-dep]"]

@@ -147,10 +151,11 @@ robstride = ["lerobot[can-dep]"]
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
lekiwi = ["lerobot[feetech]", "lerobot[pyzmq-dep]"]
unitree_g1 = [
    # "unitree-sdk2==1.0.1",
    "pyzmq>=26.2.1,<28.0.0",
    "lerobot[pyzmq-dep]",
    "lerobot[pyserial-dep]",
    "onnxruntime>=1.16.0,<2.0.0",
    "onnx>=1.16.0,<2.0.0",
    "meshcat>=0.3.0,<0.4.0",
@@ -196,7 +201,8 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]

# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]

@@ -269,6 +275,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"

# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]

@@ -33,7 +33,7 @@ import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
@@ -76,6 +76,7 @@ class Reachy2Camera(Camera):
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
require_package("reachy2_sdk", extra="reachy2")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -19,16 +19,18 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
from lerobot.utils.import_utils import _pyrealsense2_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _pyrealsense2_available:
|
||||
import pyrealsense2 as rs
|
||||
else:
|
||||
rs = None
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
@@ -112,7 +114,7 @@ class RealSenseCamera(Camera):
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
|
||||
require_package("pyrealsense2", extra="intelrealsense")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -28,12 +28,19 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.import_utils import _zmq_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _zmq_available:
|
||||
import zmq
|
||||
else:
|
||||
zmq = None
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -74,8 +81,8 @@ class ZMQCamera(Camera):
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZMQCameraConfig):
|
||||
require_package("pyzmq", extra="pyzmq-dep", import_name="zmq")
|
||||
super().__init__(config)
|
||||
import zmq
|
||||
|
||||
self.config = config
|
||||
self.server_address = config.server_address
|
||||
@@ -117,8 +124,6 @@ class ZMQCamera(Camera):
|
||||
logger.info(f"Connecting to {self}...")
|
||||
|
||||
try:
|
||||
import zmq
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.SUB)
|
||||
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
@@ -180,11 +185,8 @@ class ZMQCamera(Camera):
|
||||
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
except zmq.Again as e:
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
|
||||
# Decode JSON message
|
||||
data = json.loads(message)
|
||||
|
||||
```diff
@@ -28,6 +28,12 @@ import numpy as np
 import torch

 from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
+from lerobot.utils.import_utils import _deepdiff_available, require_package
+
+if TYPE_CHECKING or _deepdiff_available:
+    from deepdiff import DeepDiff
+else:
+    DeepDiff = None

 if TYPE_CHECKING:
     from lerobot.datasets import LeRobotDataset
@@ -217,10 +223,7 @@ def sanity_check_dataset_robot_compatibility(
     Raises:
         ValueError: If any of the checked metadata fields do not match.
     """
-    from lerobot.utils.import_utils import require_package
-
-    require_package("deepdiff", extra="hardware")
-    from deepdiff import DeepDiff
+    require_package("deepdiff", extra="deepdiff-dep")

     from lerobot.utils.constants import DEFAULT_FEATURES
```
```diff
@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
 Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
 """

+from .dataset import DatasetRecordConfig
 from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
 from .policies import PreTrainedConfig
 from .types import (
@@ -39,6 +40,7 @@ __all__ = [
     "PolicyFeature",
     "RTCAttentionSchedule",
     # Config classes
+    "DatasetRecordConfig",
     "DatasetConfig",
     "EvalConfig",
     "PeftConfig",
```
```diff
@@ -0,0 +1,77 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+
+@dataclass
+class DatasetRecordConfig:
+    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
+    repo_id: str = ""
+    # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
+    single_task: str = ""
+    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
+    root: str | Path | None = None
+    # Limit the frames per second.
+    fps: int = 30
+    # Number of seconds for data recording for each episode.
+    episode_time_s: int | float = 60
+    # Number of seconds for resetting the environment after each episode.
+    reset_time_s: int | float = 60
+    # Number of episodes to record.
+    num_episodes: int = 50
+    # Encode frames in the dataset into video
+    video: bool = True
+    # Upload dataset to Hugging Face hub.
+    push_to_hub: bool = True
+    # Upload on private repository on the Hugging Face hub.
+    private: bool = False
+    # Add tags to your dataset on the hub.
+    tags: list[str] | None = None
+    # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
+    # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
+    # and threads depends on your system. We recommend 4 threads per camera with 0 processes.
+    # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
+    num_image_writer_processes: int = 0
+    # Number of threads writing the frames as png images on disk, per camera.
+    # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
+    # Not enough threads might cause low camera fps.
+    num_image_writer_threads_per_camera: int = 4
+    # Number of episodes to record before batch encoding videos
+    # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
+    video_encoding_batch_size: int = 1
+    # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
+    # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
+    # Use 'auto' to auto-detect the best available hardware encoder.
+    vcodec: str = "libsvtav1"
+    # Enable streaming video encoding: encode frames in real-time during capture instead
+    # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
+    streaming_encoding: bool = False
+    # Maximum number of frames to buffer per camera when using streaming encoding.
+    # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
+    encoder_queue_maxsize: int = 30
+    # Number of threads per encoder instance. None = auto (codec default).
+    # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
+    encoder_threads: int | None = None
+    # Rename map for the observation to override the image and state keys
+    rename_map: dict[str, str] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        if self.repo_id:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            self.repo_id = f"{self.repo_id}_{timestamp}"
```
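The `__post_init__` above suffixes every non-empty `repo_id` with a wall-clock timestamp, so repeated recordings never collide on the Hub. A stripped-down sketch of just that behavior (`RecordConfig` is a simplified stand-in for `DatasetRecordConfig`):

```python
import re
from dataclasses import dataclass
from datetime import datetime


@dataclass
class RecordConfig:
    """Simplified stand-in for DatasetRecordConfig, keeping only repo_id handling."""

    repo_id: str = ""

    def __post_init__(self) -> None:
        # Append e.g. "_20240101_120000" so each recording gets a unique repo id.
        if self.repo_id:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            self.repo_id = f"{self.repo_id}_{timestamp}"
```

Note the empty-string default passes the `if self.repo_id:` guard untouched, so constructing the config without a repo id stays valid.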
```diff
@@ -35,6 +35,9 @@ class DatasetConfig:
     revision: str | None = None
     use_imagenet_stats: bool = True
     video_backend: str = field(default_factory=get_safe_default_codec)
+    # When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
+    # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
+    return_uint8: bool = False
     streaming: bool = False

     def __post_init__(self) -> None:
@@ -56,6 +56,8 @@ class TrainPipelineConfig(HubMixin):
     # Number of workers for the dataloader.
     num_workers: int = 4
     batch_size: int = 8
+    prefetch_factor: int = 4
+    persistent_workers: bool = True
     steps: int = 100_000
     eval_freq: int = 20_000
     log_freq: int = 200
```
```diff
@@ -16,6 +16,7 @@
 """Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""

 from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path

 import datasets
@@ -49,6 +50,7 @@ class DatasetReader:
         video_backend: str,
         delta_timestamps: dict[str, list[float]] | None,
         image_transforms: Callable | None,
+        return_uint8: bool = False,
     ):
         """Initialize the reader with metadata, filtering, and transform config.

@@ -73,6 +75,7 @@ class DatasetReader:
         self._tolerance_s = tolerance_s
         self._video_backend = video_backend
         self._image_transforms = image_transforms
+        self._return_uint8 = return_uint8

         self.hf_dataset: datasets.Dataset | None = None
         self._absolute_to_relative_idx: dict[int, int] | None = None
@@ -105,10 +108,8 @@ class DatasetReader:
         """Build absolute-to-relative index mapping from loaded hf_dataset."""
         self._absolute_to_relative_idx = None
         if self.episodes is not None and self.hf_dataset is not None:
-            self._absolute_to_relative_idx = {
-                abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
-                for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
-            }
+            indices = self.hf_dataset.data.column("index").to_numpy()
+            self._absolute_to_relative_idx = dict(zip(indices.tolist(), range(len(indices)), strict=True))

     @property
     def num_frames(self) -> int:
```
```diff
@@ -235,16 +236,30 @@ class DatasetReader:
         Segmentation Fault.
         """
         ep = self._meta.episodes[ep_idx]
-        item = {}
-        for vid_key, query_ts in query_timestamps.items():
+
+        def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
             from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
             shifted_query_ts = [from_timestamp + ts for ts in query_ts]

             video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
-            frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
-            item[vid_key] = frames.squeeze(0)
+            frames = decode_video_frames(
+                video_path,
+                shifted_query_ts,
+                self._tolerance_s,
+                self._video_backend,
+                return_uint8=self._return_uint8,
+            )
+            return vid_key, frames.squeeze(0)

-        return item
+        items = list(query_timestamps.items())
+
+        # Single camera: no threading overhead
+        if len(items) <= 1:
+            return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}
+
+        # Multi-camera: decode in parallel (video decoding releases the GIL)
+        with ThreadPoolExecutor(max_workers=len(items)) as pool:
+            futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
+            return dict(f.result() for f in futures)

     def get_item(self, idx) -> dict:
         """Core __getitem__ logic. Assumes hf_dataset is loaded.
```
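The refactor above turns the sequential per-camera loop into a per-key worker function, keeps the single-camera path thread-free, and fans out with a `ThreadPoolExecutor` when several cameras are present (threads help here because video decoding releases the GIL). The shape of that split, with an injected `decode_one(key, ts) -> (key, result)` callable standing in for `_decode_single`:

```python
from concurrent.futures import ThreadPoolExecutor


def decode_all(decode_one, query_timestamps: dict) -> dict:
    """Apply decode_one(key, ts) per camera key, in parallel only when there
    is more than one key."""
    items = list(query_timestamps.items())
    if len(items) <= 1:  # single camera: skip thread-pool setup cost
        return {key: decode_one(key, ts)[1] for key, ts in items}
    with ThreadPoolExecutor(max_workers=len(items)) as pool:
        futures = [pool.submit(decode_one, k, ts) for k, ts in items]
        return dict(f.result() for f in futures)  # (key, value) pairs -> dict
```

Having the worker return a `(key, value)` tuple lets the caller rebuild the result dict from futures without tracking which future belongs to which camera.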
```diff
@@ -597,7 +597,7 @@ class DatasetWriter:

     def cleanup_interrupted_episode(self, episode_index: int) -> None:
         """Remove temporary image directories for an interrupted episode."""
-        for key in self._meta.video_keys:
+        for key in self._meta.camera_keys:
             img_dir = self._get_image_file_path(
                 episode_index=episode_index, image_key=key, frame_index=0
             ).parent
```
```diff
@@ -92,6 +92,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
             image_transforms=image_transforms,
             revision=cfg.dataset.revision,
             video_backend=cfg.dataset.video_backend,
+            return_uint8=True,
             tolerance_s=cfg.tolerance_s,
         )
     else:
@@ -104,6 +105,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
             revision=cfg.dataset.revision,
             max_num_shards=cfg.num_workers,
             tolerance_s=cfg.tolerance_s,
+            return_uint8=True,
         )
     else:
         raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
```
```diff
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
     def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)
-        except Exception as e:
+        except BaseException:
             dataset = kwargs.get("dataset")
             writer = getattr(dataset, "writer", None) if dataset else None
             if writer is not None and writer.image_writer is not None:
                 logger.warning("Waiting for image writer to terminate...")
                 writer.image_writer.stop()
-            raise e
+            raise

     return wrapper
```
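Two details in the change above are worth spelling out: `except BaseException` also catches `KeyboardInterrupt` and `SystemExit` (which `except Exception` does not), so Ctrl-C during recording still stops the image writer cleanly; and a bare `raise` re-raises the active exception with its original traceback, whereas `raise e` restarts the traceback at the re-raise site. A generic sketch of the pattern with an injected cleanup callback (names here are illustrative, not the lerobot API):

```python
import functools


def with_cleanup(cleanup):
    """Decorator: run cleanup() on ANY exception, including KeyboardInterrupt,
    then re-raise with a bare `raise` to preserve the original traceback."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except BaseException:
                cleanup()
                raise  # bare raise keeps the active exception and traceback
        return wrapper

    return decorator
```

The cleanup runs exactly once per failing call and never swallows the exception, which matches what `safe_stop_image_writer` needs: stop the writer, then let the interrupt propagate.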
```diff
@@ -56,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         force_cache_sync: bool = False,
         download_videos: bool = True,
         video_backend: str | None = None,
+        return_uint8: bool = False,
         batch_encoding_size: int = 1,
         vcodec: str = "libsvtav1",
         streaming_encoding: bool = False,
@@ -202,6 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.tolerance_s = tolerance_s
         self.revision = revision if revision else CODEBASE_VERSION
         self._video_backend = video_backend if video_backend else get_safe_default_codec()
+        self._return_uint8 = return_uint8
         self._batch_encoding_size = batch_encoding_size
         self._vcodec = resolve_vcodec(vcodec)
         self._encoder_threads = encoder_threads
@@ -225,6 +227,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             video_backend=self._video_backend,
             delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
+            return_uint8=self._return_uint8,
         )

         # Load actual data
@@ -288,6 +291,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             video_backend=self._video_backend,
             delta_timestamps=self.delta_timestamps,
             image_transforms=self.image_transforms,
+            return_uint8=self._return_uint8,
         )
         return self.reader

@@ -683,6 +687,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_timestamps = None
         obj.episodes = None
         obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
+        obj._return_uint8 = False
         obj._batch_encoding_size = batch_encoding_size
         obj._vcodec = vcodec
         obj._encoder_threads = encoder_threads
@@ -775,6 +780,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.delta_timestamps = None
         obj.episodes = None
         obj._video_backend = video_backend if video_backend else get_safe_default_codec()
+        obj._return_uint8 = False
         obj._batch_encoding_size = batch_encoding_size
         obj._vcodec = vcodec
         obj._encoder_threads = encoder_threads
```
```diff
@@ -251,6 +251,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         seed: int = 42,
         rng: np.random.Generator | None = None,
         shuffle: bool = True,
+        return_uint8: bool = False,
     ):
         """Initialize a StreamingLeRobotDataset.

@@ -288,6 +289,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):

         self.streaming = streaming
         self.buffer_size = buffer_size
+        self._return_uint8 = return_uint8

         # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
         self.video_decoder_cache = None
@@ -553,7 +555,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
         video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
         frames = decode_video_frames_torchcodec(
-            video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
+            video_path,
+            query_ts,
+            self.tolerance_s,
+            decoder_cache=self.video_decoder_cache,
+            return_uint8=self._return_uint8,
         )

         item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
```
```diff
@@ -71,8 +71,8 @@ class ForwardCompatibilityError(CompatibilityError):


 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
-DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # Max size per file
-DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200  # Max size per file
+DEFAULT_DATA_FILE_SIZE_IN_MB = 50  # Max size per file
+DEFAULT_VIDEO_FILE_SIZE_IN_MB = 100  # Max size per file

 INFO_PATH = "meta/info.json"
 STATS_PATH = "meta/stats.json"
```
```diff
@@ -123,6 +123,7 @@ def decode_video_frames(
     timestamps: list[float],
     tolerance_s: float,
     backend: str | None = None,
+    return_uint8: bool = False,
 ) -> torch.Tensor:
     """
     Decodes video frames using the specified backend.
@@ -131,19 +132,23 @@
         video_path (Path): Path to the video file.
         timestamps (list[float]): List of timestamps to extract frames.
         tolerance_s (float): Allowed deviation in seconds for frame retrieval.
-        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
+        backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
+        return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
+            This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.

     Returns:
-        torch.Tensor: Decoded frames.
+        torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).

     Currently supports torchcodec on cpu and pyav.
     """
     if backend is None:
         backend = get_safe_default_codec()
     if backend == "torchcodec":
-        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
+        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
     elif backend in ["pyav", "video_reader"]:
-        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+        return decode_video_frames_torchvision(
+            video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
+        )
     else:
         raise ValueError(f"Unsupported video backend: {backend}")
@@ -154,6 +159,7 @@ def decode_video_frames_torchvision(
     tolerance_s: float,
     backend: str = "pyav",
     log_loaded_timestamps: bool = False,
+    return_uint8: bool = False,
 ) -> torch.Tensor:
     """Loads frames associated to the requested timestamps of a video

@@ -240,14 +246,17 @@
     if log_loaded_timestamps:
         logger.info(f"{closest_ts=}")

-    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
-    closest_frames = closest_frames.type(torch.float32) / 255
-
     if len(timestamps) != len(closest_frames):
         raise FrameTimestampError(
             f"Number of retrieved frames ({len(closest_frames)}) does not match "
             f"number of queried timestamps ({len(timestamps)})"
         )

+    if return_uint8:
+        return closest_frames
+
+    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
+    closest_frames = closest_frames.type(torch.float32) / 255
     return closest_frames

@@ -306,6 +315,7 @@ def decode_video_frames_torchcodec(
     tolerance_s: float,
     log_loaded_timestamps: bool = False,
     decoder_cache: VideoDecoderCache | None = None,
+    return_uint8: bool = False,
 ) -> torch.Tensor:
     """Loads frames associated with the requested timestamps of a video using torchcodec.

@@ -373,14 +383,16 @@
     if log_loaded_timestamps:
         logger.info(f"{closest_ts=}")

-    # convert to float32 in [0,1] range
-    closest_frames = (closest_frames / 255.0).type(torch.float32)
-
     if not len(timestamps) == len(closest_frames):
         raise FrameTimestampError(
             f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
         )

+    if return_uint8:
+        return closest_frames
+
+    # convert to float32 in [0,1] range
+    closest_frames = (closest_frames / 255.0).type(torch.float32)
     return closest_frames
```
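The hunks above do two things: dispatch on a `backend` string and thread a `return_uint8` flag through so the `/255` float32 normalization can be skipped (a uint8 frame is 4x smaller than float32, which cuts DataLoader IPC traffic; normalization then happens later, e.g. on the GPU). A toy sketch of that control flow, with a stub decoder and plain ints standing in for uint8 tensors (none of this is the real lerobot API):

```python
def _decode_stub(path: str, timestamps: list[float]) -> list[int]:
    """Fake decoder: one uint8-like pixel value per requested timestamp."""
    return [128 for _ in timestamps]


def decode_frames(
    path: str,
    timestamps: list[float],
    backend: str = "pyav",
    return_uint8: bool = False,
) -> list:
    if backend not in ("torchcodec", "pyav", "video_reader"):
        raise ValueError(f"Unsupported video backend: {backend}")
    frames = _decode_stub(path, timestamps)
    if return_uint8:
        return frames  # defer normalization: 4x less data to move between workers
    return [v / 255.0 for v in frames]  # default path: values in [0, 1]
```

Note the ordering change in the real hunks as well: the frame-count check now runs *before* the early uint8 return, so both output modes get the same validation.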
```diff
@@ -12,8 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np

+from lerobot.utils.import_utils import _placo_available, require_package
+
+if TYPE_CHECKING or _placo_available:
+    import placo  # type: ignore[import-not-found]
+else:
+    placo = None
+

 class RobotKinematics:
     """Robot kinematics using placo library for forward and inverse kinematics."""
@@ -32,13 +43,7 @@ class RobotKinematics:
             target_frame_name (str): Name of the end-effector frame in the URDF
             joint_names (list[str] | None): List of joint names to use for the kinematics solver
         """
-        try:
-            import placo  # type: ignore[import-not-found]  # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
-        except ImportError as e:
-            raise ImportError(
-                "placo is required for RobotKinematics. "
-                "Please install the optional dependencies of `kinematics` in the package."
-            ) from e
+        require_package("placo", extra="placo-dep")

         self.robot = placo.RobotWrapper(urdf_path)
         self.solver = placo.KinematicsSolver(self.robot)
```
```diff
@@ -24,7 +24,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, TypedDict

 from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
-from lerobot.utils.import_utils import _can_available
+from lerobot.utils.import_utils import _can_available, require_package

 if TYPE_CHECKING or _can_available:
     import can
@@ -111,6 +111,7 @@ class DamiaoMotorsBus(MotorsBusBase):
             bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
             data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
         """
+        require_package("python-can", extra="damiao", import_name="can")
         super().__init__(port, motors, calibration)
         self.port = port
         self.can_interface = can_interface
@@ -356,8 +356,8 @@ class SerialMotorsBus(MotorsBusBase):
         motors: dict[str, Motor],
         calibration: dict[str, MotorCalibration] | None = None,
     ):
-        require_package("pyserial", extra="hardware", import_name="serial")
-        require_package("deepdiff", extra="hardware")
+        require_package("pyserial", extra="pyserial-dep", import_name="serial")
+        require_package("deepdiff", extra="deepdiff-dep")
         super().__init__(port, motors, calibration)

         self.port_handler: PortHandler
@@ -23,12 +23,12 @@ from types import SimpleNamespace
 from typing import TYPE_CHECKING, Any, TypedDict

 from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
-from lerobot.utils.import_utils import _can_available
+from lerobot.utils.import_utils import _can_available, require_package

 if TYPE_CHECKING or _can_available:
     import can
 else:
-    can = SimpleNamespace(Message=object, interface=None)
+    can = SimpleNamespace(Message=object, interface=None, BusABC=object)
 import numpy as np

 from lerobot.utils.errors import DeviceNotConnectedError
@@ -106,6 +106,7 @@ class RobstrideMotorsBus(MotorsBusBase):
             bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
             data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
         """
+        require_package("python-can", extra="robstride", import_name="can")
         super().__init__(port, motors, calibration)
         self.port = port
         self.can_interface = can_interface
```
```diff
@@ -18,14 +18,21 @@ import logging
 import math
 from dataclasses import asdict, dataclass
 from pathlib import Path
+from typing import TYPE_CHECKING

 import draccus
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler

 from lerobot.utils.constants import SCHEDULER_STATE
+from lerobot.utils.import_utils import _diffusers_available, require_package
 from lerobot.utils.io_utils import deserialize_json_into_object, write_json

+if TYPE_CHECKING or _diffusers_available:
+    from diffusers.optimization import get_scheduler
+else:
+    get_scheduler = None
+

 @dataclass
 class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
@@ -47,10 +54,7 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
     num_warmup_steps: int | None = None

     def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
-        from lerobot.utils.import_utils import require_package
-
         require_package("diffusers", extra="diffusion")
-        from diffusers.optimization import get_scheduler

         kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
         return get_scheduler(**kwargs)
```
```diff
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
+
 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
 from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
@@ -21,7 +23,6 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
 from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
 from .pi05.configuration_pi05 import PI05Config as PI05Config
 from .pretrained import PreTrainedPolicy as PreTrainedPolicy
-from .rtc import ActionInterpolator as ActionInterpolator
 from .sac.configuration_sac import SACConfig as SACConfig
 from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
 from .sarm.configuration_sarm import SARMConfig as SARMConfig
```
```diff
@@ -23,6 +23,7 @@ TODO(alexander-soare):
 import math
 from collections import deque
 from collections.abc import Callable
+from typing import TYPE_CHECKING

 import einops
 import numpy as np
@@ -32,6 +33,14 @@ import torchvision
 from torch import Tensor, nn

 from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
+from lerobot.utils.import_utils import _diffusers_available, require_package
+
+if TYPE_CHECKING or _diffusers_available:
+    from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+else:
+    DDIMScheduler = None
+    DDPMScheduler = None

 from ..pretrained import PreTrainedPolicy
 from ..utils import (
@@ -64,6 +73,7 @@ class DiffusionPolicy(PreTrainedPolicy):
         dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
             that they will be passed with a call to `load_state_dict` before the policy is used.
         """
+        require_package("diffusers", extra="diffusion")
         super().__init__(config)
         config.validate_features()
         self.config = config
@@ -155,11 +165,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict):
     Factory for noise scheduler instances of the requested type. All kwargs are passed
     to the scheduler.
     """
-    from lerobot.utils.import_utils import require_package
-
     require_package("diffusers", extra="diffusion")
-    from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

     if name == "DDPM":
         return DDPMScheduler(**kwargs)
```
```diff
@@ -43,6 +43,7 @@ from torch import Tensor

 from lerobot.configs import FeatureType, PolicyFeature
 from lerobot.utils.constants import ACTION, OBS_IMAGES
+from lerobot.utils.import_utils import require_package

 from ..pretrained import PreTrainedPolicy
 from .configuration_groot import GrootConfig
@@ -59,6 +60,7 @@ class GrootPolicy(PreTrainedPolicy):

     def __init__(self, config: GrootConfig, **kwargs):
         """Initialize Groot policy wrapper."""
+        require_package("transformers", extra="groot")
         super().__init__(config)
         config.validate_features()
         self.config = config
```
```diff
@@ -36,7 +36,7 @@ import torch.nn.functional as F  # noqa: N812
 import torchvision
 from torch import Tensor

-from lerobot.utils.import_utils import _transformers_available
+from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package

 from .configuration_multi_task_dit import MultiTaskDiTConfig

@@ -46,6 +46,13 @@ if TYPE_CHECKING or _transformers_available:
 else:
     CLIPTextModel = None
     CLIPVisionModel = None
+
+if TYPE_CHECKING or _diffusers_available:
+    from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+else:
+    DDIMScheduler = None
+    DDPMScheduler = None
 from lerobot.utils.constants import (
     ACTION,
     OBS_IMAGES,
@@ -65,6 +72,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
     name = "multi_task_dit"

     def __init__(self, config: MultiTaskDiTConfig, **kwargs):
+        require_package("transformers", extra="multi_task_dit")
+        require_package("diffusers", extra="multi_task_dit")
         super().__init__(config)
         config.validate_features()
         self.config = config
@@ -643,12 +652,6 @@ class DiffusionObjective(nn.Module):
             "prediction_type": config.prediction_type,
         }

-        from lerobot.utils.import_utils import require_package
-
-        require_package("diffusers", extra="multi_task_dit")
-        from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-        from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-
         if config.noise_scheduler_type == "DDPM":
             self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
         elif config.noise_scheduler_type == "DDIM":
```
```diff
@@ -26,7 +26,7 @@ import torch
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor, nn

-from lerobot.utils.import_utils import _transformers_available
+from lerobot.utils.import_utils import _transformers_available, require_package

 # Conditional import for type checking and lazy loading
 if TYPE_CHECKING or _transformers_available:
@@ -947,6 +947,7 @@ class PI0Policy(PreTrainedPolicy):
         Args:
             config: Policy configuration class instance.
         """
+        require_package("transformers", extra="pi")
         super().__init__(config)
         config.validate_features()
         self.config = config
```
```diff
@@ -26,7 +26,7 @@ import torch
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor, nn

-from lerobot.utils.import_utils import _transformers_available
+from lerobot.utils.import_utils import _transformers_available, require_package

 # Conditional import for type checking and lazy loading
 if TYPE_CHECKING or _transformers_available:
@@ -918,6 +918,7 @@ class PI05Policy(PreTrainedPolicy):
         Args:
             config: Policy configuration class instance.
         """
+        require_package("transformers", extra="pi")
         super().__init__(config)
         config.validate_features()
         self.config = config
```
```diff
@@ -26,7 +26,7 @@ import torch
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor, nn

-from lerobot.utils.import_utils import _scipy_available, _transformers_available
+from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package

 # Conditional import for type checking and lazy loading
 if TYPE_CHECKING or _scipy_available:
@@ -35,7 +35,7 @@ else:
     idct = None

 if TYPE_CHECKING or _transformers_available:
-    from transformers import AutoTokenizer
+    from transformers import AutoProcessor, AutoTokenizer
     from transformers.models.auto import CONFIG_MAPPING

     from ..pi_gemma import (
@@ -44,6 +44,7 @@ if TYPE_CHECKING or _transformers_available:
     )
 else:
     CONFIG_MAPPING = None
+    AutoProcessor = None
     AutoTokenizer = None
     PiGemmaModel = None
     PaliGemmaForConditionalGenerationWithPiGemma = None
@@ -826,14 +827,14 @@ class PI0FastPolicy(PreTrainedPolicy):
         Args:
             config: Policy configuration class instance.
         """
+        require_package("transformers", extra="pi")
+        require_package("scipy", extra="pi")
         super().__init__(config)
         config.validate_features()
         self.config = config

         # Load tokenizers first
         try:
-            from transformers import AutoProcessor, AutoTokenizer
-
             # Load FAST tokenizer
             self.action_tokenizer = AutoProcessor.from_pretrained(
                 config.action_tokenizer_name, trust_remote_code=True
```
@@ -1,116 +1,4 @@
-# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
+from lerobot.utils.action_interpolator import ActionInterpolator
-
-"""Action interpolation for smoother robot control.
-
-Provides configurable Nx control rate by interpolating between consecutive actions.
-Useful with RTC and action-chunking policies to reduce jerkiness.
-"""
-
-from torch import Tensor
-
-
-class ActionInterpolator:
-    """Interpolates between consecutive actions for smoother control.
-
-    When enabled with multiplier N, produces N actions per policy action
-    by linearly interpolating between the previous and current action.
-
-    Example with multiplier=3:
-        prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
-
-    This effectively multiplies the control rate for smoother motion.
-
-    Usage:
-        interpolator = ActionInterpolator(multiplier=2)  # 2x control rate
-
-        # In control loop:
-        if interpolator.needs_new_action():
-            new_action = queue.get()
-            if new_action:
-                interpolator.add(new_action.cpu())
-
-        action = interpolator.get()
-        if action:
-            robot.send_action(action)
-    """
-
-    def __init__(self, multiplier: int = 1):
-        """Initialize the interpolator.
-
-        Args:
-            multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
-        """
-        if multiplier < 1:
-            raise ValueError(f"multiplier must be >= 1, got {multiplier}")
-        self.multiplier = multiplier
-        self._prev: Tensor | None = None
-        self._buffer: list[Tensor] = []
-        self._idx = 0
-
-    @property
-    def enabled(self) -> bool:
-        """Whether interpolation is active (multiplier > 1)."""
-        return self.multiplier > 1
-
-    def reset(self):
-        """Reset interpolation state (call between episodes)."""
-        self._prev = None
-        self._buffer = []
-        self._idx = 0
-
-    def needs_new_action(self) -> bool:
-        """Check if a new action is needed from the queue."""
-        return self._idx >= len(self._buffer)
-
-    def add(self, action: Tensor) -> None:
-        """Add a new action and compute interpolated sequence.
-
-        Args:
-            action: New action tensor from policy/queue (already on CPU).
-        """
-        if self.multiplier > 1 and self._prev is not None:
-            self._buffer = []
-            for i in range(1, self.multiplier + 1):
-                t = i / self.multiplier
-                interp = self._prev + t * (action - self._prev)
-                self._buffer.append(interp)
-        else:
-            # First step: no previous action yet, so run at base FPS without interpolation.
-            self._buffer = [action.clone()]
-        self._prev = action.clone()
-        self._idx = 0
-
-    def get(self) -> Tensor | None:
-        """Get the next interpolated action.
-
-        Returns:
-            Next action tensor, or None if buffer is exhausted.
-        """
-        if self._idx >= len(self._buffer):
-            return None
-        action = self._buffer[self._idx]
-        self._idx += 1
-        return action
-
-    def get_control_interval(self, fps: float) -> float:
-        """Get the control interval based on interpolation multiplier.
-
-        Args:
-            fps: Base frames per second.
-
-        Returns:
-            Control interval in seconds (divided by multiplier).
-        """
-        return 1.0 / (fps * self.multiplier)
+
+__all__ = ["ActionInterpolator"]

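The interpolation in `add` above is plain linear blending between the previous and current action. With Python lists standing in for tensors (a sketch of the math only, not the lerobot API):

```python
def interpolate(prev: list[float], action: list[float], multiplier: int) -> list[list[float]]:
    """Step i of N blends prev toward action by t = i / N, so step N lands exactly on action."""
    return [
        [p + (i / multiplier) * (a - p) for p, a in zip(prev, action)]
        for i in range(1, multiplier + 1)
    ]


steps = interpolate([0.0, 0.0], [3.0, -3.0], multiplier=3)
print(steps)  # three steps ending exactly at [3.0, -3.0]
```

At `fps=30` with `multiplier=2`, `get_control_interval` returns `1 / (30 * 2) ≈ 0.0167 s`, i.e. the control loop runs twice per policy action.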
@@ -92,10 +92,10 @@ class ActionQueue:
         Returns:
             int: Number of unconsumed actions.
         """
-        if self.queue is None:
-            return 0
-        length = len(self.queue)
-        return length - self.last_index
+        with self.lock:
+            if self.queue is None:
+                return 0
+            return len(self.queue) - self.last_index

     def empty(self) -> bool:
         """Check if the queue is empty.
@@ -103,11 +103,10 @@ class ActionQueue:
         Returns:
             bool: True if no actions remain, False otherwise.
         """
-        if self.queue is None:
-            return True
-
-        length = len(self.queue)
-        return length - self.last_index <= 0
+        with self.lock:
+            if self.queue is None:
+                return True
+            return len(self.queue) - self.last_index <= 0

     def get_action_index(self) -> int:
         """Get the current action consumption index.
@@ -115,7 +114,8 @@ class ActionQueue:
         Returns:
             int: Index of the next action to be consumed.
         """
-        return self.last_index
+        with self.lock:
+            return self.last_index

     def get_left_over(self) -> Tensor | None:
         """Get leftover original actions for RTC prev_chunk_left_over.

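The `qsize`/`empty` rewrites close a check-then-act race: a writer thread could replace `self.queue` between the `None` test and the `len()` call. A minimal stand-in (hypothetical class, not the real `ActionQueue`) showing the locked pattern:

```python
import threading
from collections import deque


class TinyActionQueue:
    """Every compound read of (queue, last_index) happens under one lock."""

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.queue: deque | None = None
        self.last_index = 0

    def put(self, actions) -> None:
        with self.lock:
            self.queue = deque(actions)
            self.last_index = 0

    def qsize(self) -> int:
        with self.lock:  # the None check and len() see the same queue object
            if self.queue is None:
                return 0
            return len(self.queue) - self.last_index

    def empty(self) -> bool:
        with self.lock:
            if self.queue is None:
                return True
            return len(self.queue) - self.last_index <= 0


q = TinyActionQueue()
assert q.empty() and q.qsize() == 0
q.put([0.1, 0.2, 0.3])
assert q.qsize() == 3 and not q.empty()
```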
@@ -62,6 +62,7 @@ from torch import Tensor, nn

 from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
 from lerobot.utils.device_utils import get_safe_dtype
+from lerobot.utils.import_utils import require_package

 from ..pretrained import PreTrainedPolicy
 from ..rtc.modeling_rtc import RTCProcessor
@@ -239,6 +240,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
             the configuration class is used.
         """

+        require_package("transformers", extra="smolvla")
         super().__init__(config)
         config.validate_features()
         self.config = config

@@ -27,7 +27,7 @@ import torch.distributed as distributed
 import torch.nn.functional as F  # noqa: N812
 from einops import pack, rearrange, reduce, repeat, unpack
 from torch import einsum, nn
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.optim import Optimizer

 from .configuration_vqbet import VQBeTConfig
@@ -1370,7 +1370,7 @@ class EuclideanCodebook(nn.Module):
         batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
         self.replace(batch_samples, batch_mask=expired_codes)

-    @autocast(enabled=False)
+    @autocast("cuda", enabled=False)
     def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = (

@@ -76,6 +76,7 @@ from lerobot.transport.utils import (
 )
 from lerobot.types import TransitionKey
 from lerobot.utils.device_utils import get_safe_torch_device
+from lerobot.utils.process import ProcessSignalHandler
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.robot_utils import precise_sleep
 from lerobot.utils.transition import (
@@ -94,7 +95,6 @@ from .gym_manipulator import (
     make_robot_env,
     step_env_and_process_transition,
 )
-from .process import ProcessSignalHandler
 from .queue import get_last_item_from_queue

 # Main entry point

@@ -90,6 +90,7 @@ from lerobot.utils.constants import (
     TRAINING_STATE_DIR,
 )
 from lerobot.utils.device_utils import get_safe_torch_device
+from lerobot.utils.process import ProcessSignalHandler
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
 from lerobot.utils.utils import (
@@ -99,7 +100,6 @@ from lerobot.utils.utils import (

 from .buffer import ReplayBuffer, concatenate_batch_transitions
 from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
-from .process import ProcessSignalHandler


 @parser.wrap()

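Both scripts now import `ProcessSignalHandler` from the shared `lerobot.utils.process` module instead of a per-script copy. Its essential job is to turn SIGINT/SIGTERM into an event the control loop can poll; a hypothetical minimal version (the real class's interface may differ):

```python
import signal
import threading


class MiniSignalHandler:
    """Flip an Event on SIGINT/SIGTERM so long-running loops can exit cleanly."""

    def __init__(self) -> None:
        self.shutdown_event = threading.Event()
        signal.signal(signal.SIGINT, self._handle)
        signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum: int, frame) -> None:
        self.shutdown_event.set()


handler = MiniSignalHandler()
while not handler.shutdown_event.is_set():
    break  # a real loop would do one actor/learner step per iteration
```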
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any

 from lerobot.cameras import make_cameras_from_configs
 from lerobot.types import RobotAction, RobotObservation
-from lerobot.utils.import_utils import _reachy2_sdk_available
+from lerobot.utils.import_utils import _reachy2_sdk_available, require_package

 from ..robot import Robot
 from ..utils import ensure_safe_goal_position
@@ -81,6 +81,7 @@ class Reachy2Robot(Robot):
     name = "reachy2"

     def __init__(self, config: Reachy2RobotConfig):
+        require_package("reachy2_sdk", extra="reachy2")
         super().__init__(config)

         self.config = config

@@ -27,7 +27,7 @@ import numpy as np

 from lerobot.cameras import make_cameras_from_configs
 from lerobot.types import RobotAction, RobotObservation
-from lerobot.utils.import_utils import _unitree_sdk_available
+from lerobot.utils.import_utils import _unitree_sdk_available, require_package

 from ..robot import Robot
 from .config_unitree_g1 import UnitreeG1Config
@@ -111,6 +111,7 @@ class UnitreeG1(Robot):
     name = "unitree_g1"

     def __init__(self, config: UnitreeG1Config):
+        require_package("unitree-sdk2py", extra="unitree_g1", import_name="unitree_sdk2py")
         super().__init__(config)

         logger.info("Initialize UnitreeG1...")

@@ -0,0 +1,82 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Policy deployment engine with pluggable rollout strategies."""
+
+from lerobot.utils.import_utils import require_package
+
+require_package("datasets", extra="dataset")
+
+from .configs import (
+    BaseStrategyConfig,
+    DAggerKeyboardConfig,
+    DAggerPedalConfig,
+    DAggerStrategyConfig,
+    DatasetRecordConfig,
+    HighlightStrategyConfig,
+    RolloutConfig,
+    RolloutStrategyConfig,
+    SentryStrategyConfig,
+)
+from .context import (
+    DatasetContext,
+    HardwareContext,
+    PolicyContext,
+    ProcessorContext,
+    RolloutContext,
+    RuntimeContext,
+    build_rollout_context,
+)
+from .inference import (
+    InferenceEngine,
+    InferenceEngineConfig,
+    RTCInferenceConfig,
+    RTCInferenceEngine,
+    SyncInferenceConfig,
+    SyncInferenceEngine,
+    create_inference_engine,
+)
+from .ring_buffer import RolloutRingBuffer
+from .robot_wrapper import ThreadSafeRobot
+from .strategies import RolloutStrategy, create_strategy
+
+__all__ = [
+    "BaseStrategyConfig",
+    "DAggerKeyboardConfig",
+    "DAggerPedalConfig",
+    "DAggerStrategyConfig",
+    "DatasetContext",
+    "DatasetRecordConfig",
+    "HardwareContext",
+    "HighlightStrategyConfig",
+    "InferenceEngine",
+    "InferenceEngineConfig",
+    "PolicyContext",
+    "ProcessorContext",
+    "RTCInferenceConfig",
+    "RTCInferenceEngine",
+    "RolloutConfig",
+    "RolloutContext",
+    "RolloutRingBuffer",
+    "RolloutStrategy",
+    "RolloutStrategyConfig",
+    "RuntimeContext",
+    "SentryStrategyConfig",
+    "SyncInferenceConfig",
+    "SyncInferenceEngine",
+    "ThreadSafeRobot",
+    "build_rollout_context",
+    "create_inference_engine",
+    "create_strategy",
+]
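Among the exports above is `RolloutRingBuffer`, which backs the highlight strategy's "last N seconds" capture. The core of any such buffer is a bounded FIFO that silently drops the oldest frames; a sketch with `collections.deque` (illustrative only, not the lerobot implementation):

```python
from collections import deque


class MiniRingBuffer:
    """Keep at most `seconds * fps` frames; older frames fall off the front."""

    def __init__(self, seconds: float, fps: float) -> None:
        self.frames: deque = deque(maxlen=int(seconds * fps))

    def append(self, frame) -> None:
        self.frames.append(frame)  # O(1); evicts the oldest frame when full

    def flush(self) -> list:
        """Drain the buffer, e.g. when the user presses the save key."""
        frames = list(self.frames)
        self.frames.clear()
        return frames


buf = MiniRingBuffer(seconds=2.0, fps=3.0)  # capacity: 6 frames
for i in range(10):
    buf.append(i)
print(buf.flush())  # only the 6 most recent frames survive: [4, 5, 6, 7, 8, 9]
```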
@@ -0,0 +1,270 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration dataclasses for the rollout deployment engine."""
+
+from __future__ import annotations
+
+import abc
+import logging
+from dataclasses import dataclass, field
+
+import draccus
+
+from lerobot.configs import PreTrainedConfig, parser
+from lerobot.configs.dataset import DatasetRecordConfig
+from lerobot.robots.config import RobotConfig
+from lerobot.teleoperators.config import TeleoperatorConfig
+
+from .inference import InferenceEngineConfig, SyncInferenceConfig
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Strategy configs (polymorphic dispatch via draccus ChoiceRegistry)
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class RolloutStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
+    """Abstract base for rollout strategy configurations.
+
+    Use ``--strategy.type=<name>`` on the CLI to select a strategy.
+    """
+
+    @property
+    def type(self) -> str:
+        return self.get_choice_name(self.__class__)
+
+
+@RolloutStrategyConfig.register_subclass("base")
+@dataclass
+class BaseStrategyConfig(RolloutStrategyConfig):
+    """Autonomous rollout with no data recording."""
+
+    pass
+
+
+@RolloutStrategyConfig.register_subclass("sentry")
+@dataclass
+class SentryStrategyConfig(RolloutStrategyConfig):
+    """Continuous autonomous rollout with always-on recording.
+
+    Episode duration is derived from camera resolution, FPS, and
+    ``target_video_file_size_mb`` so that each saved episode produces a
+    video file that has crossed the target size. This aligns episode
+    boundaries with the dataset's video file chunking, so each
+    ``push_to_hub`` call uploads complete video files rather than
+    re-uploading a growing file that hasn't crossed the chunk boundary.
+    """
+
+    upload_every_n_episodes: int = 5
+    # Target video file size in MB for episode rotation. Episodes are
+    # saved once the estimated video duration would exceed this limit.
+    # Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
+    target_video_file_size_mb: float | None = None
+
+
+@RolloutStrategyConfig.register_subclass("highlight")
+@dataclass
+class HighlightStrategyConfig(RolloutStrategyConfig):
+    """Autonomous rollout with on-demand recording via ring buffer.
+
+    A memory-bounded ring buffer continuously captures telemetry. When
+    the user presses the save key, the buffer contents are flushed to
+    the dataset and live recording continues until the key is pressed
+    again.
+    """
+
+    ring_buffer_seconds: float = 30.0
+    ring_buffer_max_memory_mb: float = 2048.0
+    save_key: str = "s"
+    push_key: str = "h"
+
+
+@dataclass
+class DAggerKeyboardConfig:
+    """Keyboard key bindings for DAgger controls.
+
+    Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
+    special key names (``"space"``).
+    """
+
+    pause_resume: str = "space"
+    correction: str = "tab"
+    upload: str = "enter"
+
+
+@dataclass
+class DAggerPedalConfig:
+    """Foot pedal configuration for DAgger controls.
+
+    Pedal codes are evdev key code strings (e.g. ``"KEY_A"``).
+    """
+
+    device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
+    pause_resume: str = "KEY_A"
+    correction: str = "KEY_B"
+    upload: str = "KEY_C"
+
+
+@RolloutStrategyConfig.register_subclass("dagger")
+@dataclass
+class DAggerStrategyConfig(RolloutStrategyConfig):
+    """Human-in-the-loop data collection (DAgger / RaC).
+
+    Alternates between autonomous policy execution and human intervention.
+    Intervention frames are tagged with ``intervention=True``.
+
+    Input is controlled via either a keyboard or foot pedal, selected by
+    ``input_device``. Each device exposes three actions:
+
+    1. **pause_resume** — toggle policy execution on/off.
+    2. **correction** — toggle human correction recording.
+    3. **upload** — push dataset to hub on demand (corrections-only mode).
+
+    When ``record_autonomous=True`` (default) both autonomous and correction
+    frames are recorded with size-based episode rotation (same as Sentry)
+    and background uploading. ``push_to_hub`` is blocked while a correction
+    is in progress. Set to ``False`` to record only the human-correction
+    windows, where each correction becomes its own episode.
+    """
+
+    num_episodes: int = 10
+    record_autonomous: bool = False
+    upload_every_n_episodes: int = 5
+    # Target video file size in MB for episode rotation (record_autonomous
+    # mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
+    target_video_file_size_mb: float | None = None
+    input_device: str = "keyboard"
+    keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
+    pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
+
+    def __post_init__(self):
+        if self.input_device not in ("keyboard", "pedal"):
+            raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
+
+
+# ---------------------------------------------------------------------------
+# Top-level rollout config
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class RolloutConfig:
+    """Top-level configuration for the ``lerobot-rollout`` CLI.
+
+    Combines hardware, policy, strategy, and runtime settings. The
+    ``__post_init__`` method performs fail-fast validation to reject
+    invalid flag combinations early.
+    """
+
+    # Hardware
+    robot: RobotConfig | None = None
+    teleop: TeleoperatorConfig | None = None
+
+    # Policy (loaded from --policy.path via __post_init__)
+    policy: PreTrainedConfig | None = None
+
+    # Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger)
+    strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
+
+    # Inference backend (polymorphic: --inference.type=sync|rtc)
+    inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
+
+    # Dataset (required for sentry, highlight, dagger; None for base)
+    dataset: DatasetRecordConfig | None = None
+
+    # Runtime
+    fps: float = 30.0
+    duration: float = 0.0  # 0 = infinite (24/7 mode)
+    interpolation_multiplier: int = 1
+    device: str | None = None
+    task: str = ""
+    display_data: bool = False
+    # Display data on a remote Rerun server
+    display_ip: str | None = None
+    # Port of the remote Rerun server
+    display_port: int | None = None
+    # Whether to display compressed images in Rerun
+    display_compressed_images: bool = False
+    # Use vocal synthesis to read events
+    play_sounds: bool = True
+    resume: bool = False
+
+    # Torch compile
+    use_torch_compile: bool = False
+    torch_compile_backend: str = "inductor"
+    torch_compile_mode: str = "default"
+    compile_warmup_inferences: int = 2
+
+    def __post_init__(self):
+        """Validate config invariants and load the policy config from ``--policy.path``."""
+        # --- Strategy-specific validation ---
+        if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
+            raise ValueError("DAgger strategy requires --teleop.type to be set")
+
+        needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig))
+        if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
+            raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
+
+        if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
+            raise ValueError(
+                "Base strategy does not record data. Use sentry, highlight, or dagger for recording."
+            )
+
+        # Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
+        if (
+            isinstance(self.strategy, SentryStrategyConfig)
+            and self.dataset is not None
+            and not self.dataset.streaming_encoding
+        ):
+            logger.warning("Sentry mode forces streaming_encoding=True")
+            self.dataset.streaming_encoding = True
+
+        # Highlight writes frames while the policy is still running, so streaming is mandatory.
+        if (
+            isinstance(self.strategy, HighlightStrategyConfig)
+            and self.dataset is not None
+            and not self.dataset.streaming_encoding
+        ):
+            logger.warning("Highlight mode forces streaming_encoding=True")
+            self.dataset.streaming_encoding = True
+
+        # DAgger: streaming is mandatory only when the autonomous phase is also recorded.
+        if (
+            isinstance(self.strategy, DAggerStrategyConfig)
+            and self.strategy.record_autonomous
+            and self.dataset is not None
+            and not self.dataset.streaming_encoding
+        ):
+            logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
+            self.dataset.streaming_encoding = True
+
+        # --- Policy loading ---
+        if self.robot is None:
+            raise ValueError("--robot.type is required for rollout")
+
+        policy_path = parser.get_path_arg("policy")
+        if policy_path:
+            cli_overrides = parser.get_cli_overrides("policy")
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+            self.policy.pretrained_path = policy_path
+        if self.policy is None:
+            raise ValueError("--policy.path is required for rollout")

+    @classmethod
+    def __get_path_fields__(cls) -> list[str]:
+        return ["policy"]
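`RolloutStrategyConfig` relies on draccus's `ChoiceRegistry` to map `--strategy.type=<name>` to a concrete dataclass. The dispatch mechanism itself is a small name-to-class registry; a stdlib-only sketch of the pattern (not draccus's actual implementation, class names hypothetical):

```python
from dataclasses import dataclass


class StrategyRegistry:
    """Name-to-class registry, standing in for draccus.ChoiceRegistry (illustrative)."""

    _choices: dict = {}

    @classmethod
    def register_subclass(cls, name: str):
        def wrap(subclass):
            cls._choices[name] = subclass
            return subclass

        return wrap

    @classmethod
    def from_type(cls, name: str, **kwargs):
        return cls._choices[name](**kwargs)


@StrategyRegistry.register_subclass("sentry")
@dataclass
class SentryCfg(StrategyRegistry):
    upload_every_n_episodes: int = 5


@StrategyRegistry.register_subclass("dagger")
@dataclass
class DAggerCfg(StrategyRegistry):
    record_autonomous: bool = False


# Roughly what `--strategy.type=dagger --strategy.record_autonomous=true` resolves to:
cfg = StrategyRegistry.from_type("dagger", record_autonomous=True)
assert isinstance(cfg, DAggerCfg) and cfg.record_autonomous
```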
@@ -0,0 +1,429 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Rollout context: shared state created once before strategy dispatch.
+
+Grouped into five topical sub-contexts — :class:`RuntimeContext`,
+:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`,
+and :class:`DatasetContext` — assembled into :class:`RolloutContext`.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from threading import Event
+
+import torch
+
+from lerobot.configs import FeatureType, PreTrainedConfig
+from lerobot.datasets import (
+    LeRobotDataset,
+    aggregate_pipeline_dataset_features,
+    create_initial_features,
+)
+from lerobot.policies import get_policy_class, make_pre_post_processors
+from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.processor import (
+    PolicyProcessorPipeline,
+    RobotAction,
+    RobotObservation,
+    RobotProcessorPipeline,
+    make_default_processors,
+    rename_stats,
+)
+from lerobot.robots import make_robot_from_config
+from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
+from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
+
+from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
+from .inference import (
+    InferenceEngine,
+    RTCInferenceConfig,
+    create_inference_engine,
+)
+from .robot_wrapper import ThreadSafeRobot
+
+logger = logging.getLogger(__name__)
+
+
+def _resolve_action_key_order(
+    policy_action_names: list[str] | None, dataset_action_names: list[str]
+) -> list[str]:
+    """Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
+    if not policy_action_names:
+        return dataset_action_names
+    policy_action_names = list(policy_action_names)
+    if len(policy_action_names) != len(dataset_action_names):
+        logger.warning(
+            "policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
+            len(policy_action_names),
+            len(dataset_action_names),
+        )
+        return dataset_action_names
+    if set(dataset_action_names) != set(policy_action_names):
+        logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
+        return dataset_action_names
+    return policy_action_names
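`_resolve_action_key_order` decides which name ordering maps a flat policy output vector onto a robot action dict. The fallback logic and the final name-to-value zip look roughly like this (stand-alone sketch, not the lerobot function itself):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_action_key_order(policy_names, dataset_names):
    """Trust the policy's ordering only when it matches the dataset's names exactly as a set."""
    if not policy_names:
        return dataset_names
    if len(policy_names) != len(dataset_names) or set(policy_names) != set(dataset_names):
        logger.warning("policy action names don't match dataset; falling back to dataset order")
        return dataset_names
    return list(policy_names)


# Mapping a flat action vector onto a robot action dict by the resolved order:
order = resolve_action_key_order(["elbow.pos", "shoulder.pos"], ["shoulder.pos", "elbow.pos"])
action_vec = [0.1, 0.2]  # flat tensor output, as a plain list
robot_action = dict(zip(order, action_vec))
assert robot_action == {"elbow.pos": 0.1, "shoulder.pos": 0.2}
```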
+
+
+# ---------------------------------------------------------------------------
+# Sub-contexts
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class RuntimeContext:
+    """Runtime knobs shared with every strategy."""
+
+    cfg: RolloutConfig
+    shutdown_event: Event
+
+
+@dataclass
+class HardwareContext:
+    """Connected hardware.
+
+    The raw robot is available via ``robot_wrapper.inner`` when needed
+    (e.g. for disconnect); strategies should otherwise go through the
+    thread-safe wrapper.
+
+    ``initial_position`` stores the robot's joint positions at connect
+    time. Strategies use it to return the robot to a safe pose before
+    shutting down.
+    """
+
+    robot_wrapper: ThreadSafeRobot
+    teleop: Teleoperator | None
+    initial_position: dict | None = None
+
+
+@dataclass
+class PolicyContext:
+    """Loaded policy and its inference engine."""
+
+    policy: PreTrainedPolicy
+    preprocessor: PolicyProcessorPipeline
+    postprocessor: PolicyProcessorPipeline
+    inference: InferenceEngine
+
+
+@dataclass
+class ProcessorContext:
+    """Robot-side pipelines (run outside the policy)."""
+
+    teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
+    robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
+    robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation]
+
+
+@dataclass
+class DatasetContext:
+    """Dataset and feature bookkeeping."""
+
+    dataset: LeRobotDataset | None
+    dataset_features: dict = field(default_factory=dict)
+    hw_features: dict = field(default_factory=dict)
+    ordered_action_keys: list[str] = field(default_factory=list)
+
+
+@dataclass
+class RolloutContext:
+    """Bundle of sub-contexts passed to every rollout strategy.
+
+    Built once by :func:`build_rollout_context` before strategy dispatch.
+    """
+
+    runtime: RuntimeContext
+    hardware: HardwareContext
+    policy: PolicyContext
+    processors: ProcessorContext
+    data: DatasetContext
+
+
+# ---------------------------------------------------------------------------
+# Build
+# ---------------------------------------------------------------------------
+
+
+def build_rollout_context(
+    cfg: RolloutConfig,
+    shutdown_event: Event,
+    teleop_action_processor: RobotProcessorPipeline | None = None,
+    robot_action_processor: RobotProcessorPipeline | None = None,
+    robot_observation_processor: RobotProcessorPipeline | None = None,
+) -> RolloutContext:
+    """Wire up policy, processors, hardware, dataset, and inference engine.
+
+    The order is policy-first / hardware-last so a bad ``--policy.path``
+    fails fast without touching the robot.
+    """
+    is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
+
+    # --- 1. Policy (heavy I/O, but no hardware yet) -------------------
+    logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
+    policy_config = cfg.policy
+    policy_class = get_policy_class(policy_config.type)
+
+    full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
+    for attr in ("device", "use_amp"):
+        if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
+            cli_val = getattr(cfg.policy, attr)
+            if cli_val is not None:
+                setattr(full_config, attr, cli_val)
+
+    if hasattr(full_config, "compile_model"):
+        full_config.compile_model = cfg.use_torch_compile
+
+    if full_config.type == "vqbet" and cfg.device == "mps":
+        raise NotImplementedError(
+            "Current implementation of VQBeT does not support `mps` backend. "
+            "Please use `cpu` or `cuda` backend."
+        )
+
+    if full_config.use_peft:
+        from peft import PeftConfig, PeftModel
+
+        peft_path = cfg.policy.pretrained_path
+        peft_config = PeftConfig.from_pretrained(peft_path)
+        policy = policy_class.from_pretrained(
+            pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
+        )
+        policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
+    else:
+        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
+
+    if is_rtc:
+        policy.config.rtc_config = cfg.inference.rtc
+        if hasattr(policy, "init_rtc_processor"):
+            policy.init_rtc_processor()
+
+    policy = policy.to(cfg.device)
+    policy.eval()
+    logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
+
+    if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
+        try:
+            if hasattr(torch, "compile"):
+                compile_kwargs = {
+                    "backend": cfg.torch_compile_backend,
+                    "mode": cfg.torch_compile_mode,
+                    "options": {"triton.cudagraphs": False},
+                }
+                policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
+                logger.info("torch.compile applied to predict_action_chunk")
+        except Exception as e:
+            logger.warning("Failed to apply torch.compile: %s", e)
+
+    # --- 2. Robot-side processors (user-supplied or defaults) --------
+    if (
+        teleop_action_processor is None
+        or robot_action_processor is None
+        or robot_observation_processor is None
+    ):
+        _t, _r, _o = make_default_processors()
+        teleop_action_processor = teleop_action_processor or _t
+        robot_action_processor = robot_action_processor or _r
+        robot_observation_processor = robot_observation_processor or _o
+
+    # --- 3. Hardware (heaviest side-effect, deferred) -----------------
+    logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?")
+    robot = make_robot_from_config(cfg.robot)
+    robot.connect()
+    logger.info("Robot connected: %s", robot.name)
+
+    # Store the initial joint positions so we can return to a safe pose on shutdown.
+    initial_obs = robot.get_observation()
+    initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")}
|
||||
logger.info("Captured initial robot position (%d keys)", len(initial_position))
|
||||
|
||||
robot_wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
teleop = None
|
||||
if cfg.teleop is not None:
|
||||
logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?")
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
teleop.connect()
|
||||
logger.info("Teleoperator connected")
|
||||
|
||||
# DAgger requires teleop with motor control capabilities (enable_torque,
|
||||
# disable_torque, write_goal_positions).
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
|
||||
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
|
||||
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
|
||||
# if missing:
|
||||
# teleop.disconnect()
|
||||
# raise ValueError(
|
||||
# f"DAgger strategy requires a teleoperator with motor control methods "
|
||||
# f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
|
||||
# )
|
||||
|
||||
# --- 4. Features + action-key reconciliation ---------------------
|
||||
all_obs_features = robot.observation_features
|
||||
observation_features_hw = {
|
||||
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
|
||||
}
|
||||
action_features_hw = robot.action_features
|
||||
|
||||
# The action side is always needed: sync inference reads action names from
|
||||
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
||||
action_dataset_features = aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
)
|
||||
# Observation-side aggregation is needed because of build_dataset_frame
|
||||
observation_dataset_features = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=observation_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
)
|
||||
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
raw_action_keys = list(robot.action_features.keys())
|
||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||
ordered_action_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
raw_action_keys,
|
||||
)
|
||||
|
||||
# Validate visual features if no rename_map is active
|
||||
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
|
||||
if not rename_map:
|
||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {
|
||||
f"observation.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
||||
}
|
||||
policy_subset = expected_visuals.issubset(provided_visuals)
|
||||
hw_subset = provided_visuals.issubset(expected_visuals)
|
||||
if not (policy_subset or hw_subset):
|
||||
raise ValueError(
|
||||
f"Visual feature mismatch between policy and robot hardware.\n"
|
||||
f"Policy expects: {expected_visuals}\n"
|
||||
f"Robot provides: {provided_visuals}"
|
||||
)
|
||||
|
||||
# --- 5. Dataset -------------
|
||||
dataset = None
|
||||
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
|
||||
logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
)
|
||||
else:
|
||||
if isinstance(cfg.strategy, DAggerStrategyConfig):
|
||||
dataset_features["intervention"] = {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
if dataset is not None:
|
||||
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
|
||||
|
||||
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.dataset.rename_map if cfg.dataset else {},
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
||||
},
|
||||
)
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
"Creating inference engine (type=%s)...",
|
||||
cfg.inference.type if hasattr(cfg.inference, "type") else "sync",
|
||||
)
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
inference_strategy = create_inference_engine(
|
||||
cfg.inference,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
robot_wrapper=robot_wrapper,
|
||||
hw_features=hw_features,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task_str,
|
||||
fps=cfg.fps,
|
||||
device=cfg.device,
|
||||
use_torch_compile=cfg.use_torch_compile,
|
||||
compile_warmup_inferences=cfg.compile_warmup_inferences,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
|
||||
# --- 8. Assemble ---------------------------------------------------
|
||||
logger.info("Rollout context assembled successfully")
|
||||
return RolloutContext(
|
||||
runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event),
|
||||
hardware=HardwareContext(
|
||||
robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position
|
||||
),
|
||||
policy=PolicyContext(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
inference=inference_strategy,
|
||||
),
|
||||
processors=ProcessorContext(
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
),
|
||||
data=DatasetContext(
|
||||
dataset=dataset,
|
||||
dataset_features=dataset_features,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,39 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference engine package — backend-agnostic action production.

Concrete strategies (sync, RTC, …) expose the same small interface so
rollout strategies never branch on the inference backend.
"""

from .base import InferenceEngine
from .factory import (
    InferenceEngineConfig,
    RTCInferenceConfig,
    SyncInferenceConfig,
    create_inference_engine,
)
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine

__all__ = [
    "InferenceEngine",
    "InferenceEngineConfig",
    "RTCInferenceConfig",
    "RTCInferenceEngine",
    "SyncInferenceConfig",
    "SyncInferenceEngine",
    "create_inference_engine",
]
@@ -0,0 +1,88 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference engine ABC.

Rollout strategies consume actions through this small interface so they
do not need to know whether the inference engine is synchronous, runs in
a background thread (RTC), or comes from an external source.
"""

from __future__ import annotations

import abc

import torch


class InferenceEngine(abc.ABC):
    """Abstract backend for producing actions during rollout.

    Subclasses decide whether inference happens inline, in a background
    thread, or externally. The contract is minimal so new backends can
    be added without touching rollout strategies.

    Lifecycle
    ---------
    ``start`` — prepare the backend (e.g. launch a background thread).
    ``stop`` — shut the backend down cleanly.
    ``reset`` — clear episode-scoped state (policy hidden state, queues…).

    Action production
    -----------------
    ``get_action(obs_frame)`` — return the next action tensor, or
    ``None`` if none is available (e.g. async queue empty). Sync
    backends always compute from ``obs_frame``; async backends may
    ignore it (they get observations via ``notify_observation``).

    Optional hooks
    --------------
    ``notify_observation`` / ``pause`` / ``resume`` have a no-op default
    so rollout strategies can invoke them unconditionally.
    """

    @abc.abstractmethod
    def start(self) -> None:
        """Initialise the backend."""

    @abc.abstractmethod
    def stop(self) -> None:
        """Tear the backend down."""

    @abc.abstractmethod
    def reset(self) -> None:
        """Clear episode-scoped state."""

    @abc.abstractmethod
    def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
        """Return the next action tensor, or ``None`` if unavailable."""

    def notify_observation(self, obs: dict) -> None:  # noqa: B027
        """Publish the latest processed observation. Default: no-op."""

    def pause(self) -> None:  # noqa: B027
        """Pause background inference. Default: no-op."""

    def resume(self) -> None:  # noqa: B027
        """Resume background inference. Default: no-op."""

    @property
    def ready(self) -> bool:
        """True once the backend can produce actions (e.g. warmup done)."""
        return True

    @property
    def failed(self) -> bool:
        """True if an unrecoverable error occurred in the backend."""
        return False
@@ -0,0 +1,129 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference engine configs and factory.

Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
backend requires registering its config subclass and dispatching it in
:func:`create_inference_engine`.
"""

from __future__ import annotations

import abc
import logging
from dataclasses import dataclass, field
from threading import Event

import draccus

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.processor import PolicyProcessorPipeline

from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Configs
# ---------------------------------------------------------------------------


@dataclass
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base for inference backend configuration.

    Use ``--inference.type=<name>`` on the CLI to select a backend.
    """

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)


@InferenceEngineConfig.register_subclass("sync")
@dataclass
class SyncInferenceConfig(InferenceEngineConfig):
    """Inline synchronous inference (one policy call per control tick)."""


@InferenceEngineConfig.register_subclass("rtc")
@dataclass
class RTCInferenceConfig(InferenceEngineConfig):
    """Real-Time Chunking: async policy inference in a background thread."""

    # ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
    # constructing one here costs nothing and keeps draccus' CLI surface flat
    # (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
    rtc: RTCConfig = field(default_factory=RTCConfig)
    queue_threshold: int = 30


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------


def create_inference_engine(
    config: InferenceEngineConfig,
    *,
    policy: PreTrainedPolicy,
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    robot_wrapper: ThreadSafeRobot,
    hw_features: dict,
    dataset_features: dict,
    ordered_action_keys: list[str],
    task: str,
    fps: float,
    device: str | None,
    use_torch_compile: bool = False,
    compile_warmup_inferences: int = 2,
    shutdown_event: Event | None = None,
) -> InferenceEngine:
    """Instantiate the appropriate inference engine from a config object."""
    logger.info("Creating inference engine: %s", config.type)
    if isinstance(config, SyncInferenceConfig):
        return SyncInferenceEngine(
            policy=policy,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            dataset_features=dataset_features,
            ordered_action_keys=ordered_action_keys,
            task=task,
            device=device,
            robot_type=robot_wrapper.robot_type,
        )
    if isinstance(config, RTCInferenceConfig):
        return RTCInferenceEngine(
            policy=policy,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            robot_wrapper=robot_wrapper,
            rtc_config=config.rtc,
            hw_features=hw_features,
            task=task,
            fps=fps,
            device=device,
            use_torch_compile=use_torch_compile,
            compile_warmup_inferences=compile_warmup_inferences,
            rtc_queue_threshold=config.queue_threshold,
            shutdown_event=shutdown_event,
        )
    raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
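The factory above pairs a string-keyed config registry (draccus' `ChoiceRegistry`, driven by `--inference.type=sync|rtc`) with `isinstance` dispatch to the concrete engine. Stripped of draccus and the real engine classes, the shape of that pattern looks roughly like this (all names in this sketch are illustrative, and strings stand in for engine instances):

```python
from dataclasses import dataclass

_REGISTRY: dict[str, type] = {}


def register(name: str):
    """Map a CLI type name to a config subclass (mimics register_subclass)."""
    def deco(cls):
        _REGISTRY[name] = cls
        cls.type = name
        return cls
    return deco


@register("sync")
@dataclass
class SyncConfig:
    pass


@register("rtc")
@dataclass
class RTCConfig:
    queue_threshold: int = 30


def parse_config(type_name: str, **kwargs):
    """Resolve a type string to a config instance, like draccus does for the CLI."""
    try:
        return _REGISTRY[type_name](**kwargs)
    except KeyError:
        raise ValueError(f"Unknown inference engine type: {type_name}") from None


def create_engine(config) -> str:
    # The real factory constructs engine objects; strings keep this runnable.
    if isinstance(config, SyncConfig):
        return "SyncInferenceEngine"
    if isinstance(config, RTCConfig):
        return f"RTCInferenceEngine(threshold={config.queue_threshold})"
    raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
```

Adding a backend then means registering one more config subclass and one more `isinstance` branch, which is exactly the extension path the module docstring describes.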
@@ -0,0 +1,391 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Real-Time Chunking inference engine.

A background thread produces action chunks asynchronously via
:meth:`policy.predict_action_chunk`. The main control loop polls
``get_action`` for the next ready action; observations flow the other
way via ``notify_observation``.
"""

from __future__ import annotations

import logging
import math
import time
import traceback
from threading import Event, Lock, Thread
from typing import Any

import torch

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionQueue, LatencyTracker
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import (
    NormalizerProcessorStep,
    PolicyProcessorPipeline,
    RelativeActionsProcessorStep,
    TransitionKey,
    create_transition,
    to_relative_actions,
)
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame

from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceEngine

logger = logging.getLogger(__name__)

# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
_RTC_IDLE_SLEEP_S: float = 0.01
# Backoff between transient inference errors (per consecutive failure).
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
# Consecutive transient errors tolerated before giving up and propagating shutdown.
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
# Hard timeout for joining the RTC thread on stop().
_RTC_JOIN_TIMEOUT_S: float = 3.0


# ---------------------------------------------------------------------------
# RTC helpers
# ---------------------------------------------------------------------------


def _reanchor_relative_rtc_prefix(
    prev_actions_absolute: torch.Tensor,
    current_state: torch.Tensor,
    relative_step: RelativeActionsProcessorStep,
    normalizer_step: NormalizerProcessorStep | None,
    policy_device: torch.device | str,
) -> torch.Tensor:
    """Convert absolute leftover actions into model-space for relative-action RTC policies.

    When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
    is stored in absolute coordinates. Before feeding it back to the policy, this
    helper re-expresses those actions relative to the robot's current joint state
    and optionally normalizes them so the policy receives correctly scaled inputs.
    """
    state = current_state.detach().cpu()
    if state.dim() == 1:
        state = state.unsqueeze(0)

    action_cpu = prev_actions_absolute.detach().cpu()
    mask = relative_step._build_mask(action_cpu.shape[-1])
    relative_actions = to_relative_actions(action_cpu, state, mask)

    transition = create_transition(action=relative_actions)
    if normalizer_step is not None:
        transition = normalizer_step(transition)

    return transition[TransitionKey.ACTION].to(policy_device)


def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
    """Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
    if prev_actions.ndim != 2:
        raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
    steps, action_dim = prev_actions.shape
    if steps == target_steps:
        return prev_actions
    if steps > target_steps:
        return prev_actions[:target_steps]
    padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
    padded[:steps] = prev_actions
    return padded


# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------


class RTCInferenceEngine(InferenceEngine):
    """Async RTC inference: a background thread produces action chunks.

    ``get_action`` pops the next action from the shared queue (or
    returns ``None`` if the queue is empty). The main loop should call
    ``notify_observation`` every tick and ``pause``/``resume`` around
    human-intervention phases.
    """

    def __init__(
        self,
        policy: PreTrainedPolicy,
        preprocessor: PolicyProcessorPipeline,
        postprocessor: PolicyProcessorPipeline,
        robot_wrapper: ThreadSafeRobot,
        rtc_config: RTCConfig,
        hw_features: dict,
        task: str,
        fps: float,
        device: str | None,
        use_torch_compile: bool = False,
        compile_warmup_inferences: int = 2,
        rtc_queue_threshold: int = 30,
        shutdown_event: Event | None = None,
    ) -> None:
        self._policy = policy
        self._preprocessor = preprocessor
        self._postprocessor = postprocessor
        self._robot = robot_wrapper
        self._rtc_config = rtc_config
        self._hw_features = hw_features
        self._task = task
        self._fps = fps
        self._device = device or "cpu"
        self._use_torch_compile = use_torch_compile
        self._compile_warmup_inferences = compile_warmup_inferences
        self._rtc_queue_threshold = rtc_queue_threshold

        self._action_queue: ActionQueue | None = None
        self._obs_holder: dict[str, Any] = {}
        self._obs_lock = Lock()
        self._policy_active = Event()
        self._compile_warmup_done = Event()
        self._shutdown_event = Event()
        self._rtc_error = Event()
        self._global_shutdown_event = shutdown_event
        self._rtc_thread: Thread | None = None

        if not self._use_torch_compile:
            self._compile_warmup_done.set()
            logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
        else:
            logger.info(
                "RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
                compile_warmup_inferences,
            )

        # Processor introspection for relative-action re-anchoring.
        self._relative_step = next(
            (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
            None,
        )
        self._normalizer_step = next(
            (s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
            None,
        )
        if self._relative_step is not None:
            if self._relative_step.action_names is None:
                cfg_names = getattr(policy.config, "action_feature_names", None)
                if cfg_names:
                    self._relative_step.action_names = list(cfg_names)
                else:
                    self._relative_step.action_names = [
                        k for k in robot_wrapper.action_features if k.endswith(".pos")
                    ]
            logger.info("Relative actions enabled: RTC prefix will be re-anchored")

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    @property
    def ready(self) -> bool:
        """True once torch.compile warmup is complete (or immediately if compile is disabled)."""
        return self._compile_warmup_done.is_set()

    @property
    def failed(self) -> bool:
        """True if the RTC background thread exited due to an unrecoverable error."""
        return self._rtc_error.is_set()

    @property
    def action_queue(self) -> ActionQueue | None:
        """The shared action queue between the RTC thread and the main loop."""
        return self._action_queue

    def start(self) -> None:
        """Launch the RTC background thread."""
        self._action_queue = ActionQueue(self._rtc_config)
        self._obs_holder = {
            "obs": None,
            "robot_type": self._robot.robot_type,
        }
        self._shutdown_event.clear()
        self._rtc_thread = Thread(
            target=self._rtc_loop,
            daemon=True,
            name="RTCInference",
        )
        self._rtc_thread.start()
        logger.info("RTC inference thread started")

    def stop(self) -> None:
        """Signal the RTC thread to stop and wait for it."""
        logger.info("Stopping RTC inference thread...")
        self._shutdown_event.set()
        self._policy_active.clear()
        if self._rtc_thread is not None and self._rtc_thread.is_alive():
            self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
            if self._rtc_thread.is_alive():
                logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
            else:
                logger.info("RTC inference thread stopped")
        self._rtc_thread = None

    def pause(self) -> None:
        """Pause the RTC background thread."""
        logger.info("Pausing RTC inference thread")
        self._policy_active.clear()

    def resume(self) -> None:
        """Resume the RTC background thread."""
        logger.info("Resuming RTC inference thread")
        self._policy_active.set()

    def reset(self) -> None:
        """Reset the policy, processors, and action queue."""
        logger.info("Resetting RTC inference state (policy + processors + queue)")
        self._policy.reset()
        self._preprocessor.reset()
        self._postprocessor.reset()
        if self._action_queue is not None:
            self._action_queue.clear()

    # ------------------------------------------------------------------
    # Action production (called from main thread)
    # ------------------------------------------------------------------

    def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
        """Pop the next action from the RTC queue (ignores ``obs_frame``)."""
        if self._action_queue is None:
            return None
        return self._action_queue.get()

    def notify_observation(self, obs: dict) -> None:
        """Publish the latest observation for the RTC thread to consume."""
        with self._obs_lock:
            self._obs_holder["obs"] = obs

    # ------------------------------------------------------------------
    # RTC: background inference thread
    # ------------------------------------------------------------------

    def _rtc_loop(self) -> None:
        """Background thread that generates action chunks via RTC."""
        try:
            latency_tracker = LatencyTracker()
            time_per_chunk = 1.0 / self._fps
            policy_device = torch.device(self._device)

            warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
            inference_count = 0
            consecutive_errors = 0

            while not self._shutdown_event.is_set():
                if not self._policy_active.is_set():
                    time.sleep(_RTC_IDLE_SLEEP_S)
                    continue

                queue = self._action_queue
                with self._obs_lock:
                    obs = self._obs_holder.get("obs")
                if queue is None or obs is None:
                    time.sleep(_RTC_IDLE_SLEEP_S)
                    continue

                if queue.qsize() <= self._rtc_queue_threshold:
                    try:
                        current_time = time.perf_counter()
                        idx_before = queue.get_action_index()
                        prev_actions = queue.get_left_over()

                        latency = latency_tracker.max()
                        delay = math.ceil(latency / time_per_chunk) if latency else 0

                        obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
                        obs_batch = prepare_observation_for_inference(
                            obs_batch, policy_device, self._task, self._robot.robot_type
                        )
                        obs_batch["task"] = [self._task]

                        preprocessed = self._preprocessor(obs_batch)

                        if prev_actions is not None and self._relative_step is not None:
                            state_tensor = preprocessed.get(OBS_STATE)
                            if state_tensor is not None:
                                prev_abs = queue.get_processed_left_over()
                                if prev_abs is not None and prev_abs.numel() > 0:
                                    prev_actions = _reanchor_relative_rtc_prefix(
                                        prev_actions_absolute=prev_abs,
                                        current_state=state_tensor,
                                        relative_step=self._relative_step,
                                        normalizer_step=self._normalizer_step,
                                        policy_device=policy_device,
                                    )

                        if prev_actions is not None:
                            prev_actions = _normalize_prev_actions_length(
                                prev_actions, target_steps=self._rtc_config.execution_horizon
                            )

                        actions = self._policy.predict_action_chunk(
                            preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
                        )

                        original = actions.squeeze(0).clone()
                        processed = self._postprocessor(actions).squeeze(0)
                        new_latency = time.perf_counter() - current_time
                        new_delay = math.ceil(new_latency / time_per_chunk)

                        inference_count += 1
                        consecutive_errors = 0
                        is_warmup = self._use_torch_compile and inference_count <= warmup_required
                        if is_warmup:
                            latency_tracker.reset()
                        else:
                            latency_tracker.add(new_latency)

                        queue.merge(original, processed, new_delay, idx_before)

                        if (
|
||||
is_warmup
|
||||
and inference_count >= warmup_required
|
||||
and not self._compile_warmup_done.is_set()
|
||||
):
|
||||
self._compile_warmup_done.set()
|
||||
logger.info("Compile warmup complete (%d inferences)", inference_count)
|
||||
|
||||
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
||||
|
||||
except Exception as e:
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"RTC inference error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
|
||||
# Persistent failure: stop retrying and propagate shutdown.
|
||||
raise
|
||||
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
|
||||
else:
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Fatal error in RTC thread: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._rtc_error.set()
|
||||
# Unblock any warmup waiters so the main loop doesn't spin forever
|
||||
self._compile_warmup_done.set()
|
||||
# Signal the top-level shutdown so strategies exit their control loops
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
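For intuition, the `delay` computation in the loop above converts a measured inference latency into a number of control steps that the RTC queue must bridge. A minimal standalone sketch of the same arithmetic (`chunks_of_delay` is a hypothetical helper, not part of the codebase; the 30 fps value is illustrative):

```python
import math


def chunks_of_delay(latency_s: float, fps: float) -> int:
    # Mirrors the RTC loop: delay = ceil(latency / time_per_chunk), 0 before any latency is recorded.
    time_per_chunk = 1.0 / fps
    return math.ceil(latency_s / time_per_chunk) if latency_s else 0


print(chunks_of_delay(0.085, 30.0))  # → 3: an 85 ms inference spans ~3 control steps at 30 fps
print(chunks_of_delay(0.0, 30.0))  # → 0: no latency recorded yet
```

The ceiling is the conservative choice: rounding up means the policy is asked to plan past the worst-case point where its chunk will actually start executing.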
@@ -0,0 +1,107 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Synchronous inference engine: inline policy call per control tick."""

from __future__ import annotations

import logging
from contextlib import nullcontext
from copy import copy

import torch

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline

from .base import InferenceEngine

logger = logging.getLogger(__name__)


class SyncInferenceEngine(InferenceEngine):
    """Inline synchronous inference: compute one action per call.

    ``get_action`` runs the full policy pipeline (pre/post-processor +
    ``select_action``) on the given observation frame and returns a
    CPU action tensor reordered to match the dataset action keys.
    """

    def __init__(
        self,
        policy: PreTrainedPolicy,
        preprocessor: PolicyProcessorPipeline,
        postprocessor: PolicyProcessorPipeline,
        dataset_features: dict,
        ordered_action_keys: list[str],
        task: str,
        device: str | None,
        robot_type: str,
    ) -> None:
        self._policy = policy
        self._preprocessor = preprocessor
        self._postprocessor = postprocessor
        self._dataset_features = dataset_features
        self._ordered_action_keys = ordered_action_keys
        self._task = task
        self._device = torch.device(device or "cpu")
        self._robot_type = robot_type
        logger.info(
            "SyncInferenceEngine initialized (device=%s, action_keys=%d)",
            self._device,
            len(ordered_action_keys),
        )

    def start(self) -> None:
        """No background resources to start."""
        logger.info("SyncInferenceEngine started (inline mode — no background thread)")

    def stop(self) -> None:
        """No background resources to stop."""
        logger.info("SyncInferenceEngine stopped")

    def reset(self) -> None:
        """Reset the policy and pre/post-processors."""
        logger.info("Resetting sync inference state (policy + processors)")
        self._policy.reset()
        self._preprocessor.reset()
        self._postprocessor.reset()

    def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
        """Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
        if obs_frame is None:
            return None
        # Shallow copy is intentional: the caller (`send_next_action`) builds
        # ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
        # tensor/array values are not shared with any other reader.
        observation = copy(obs_frame)
        autocast_ctx = (
            torch.autocast(device_type=self._device.type)
            if self._device.type == "cuda" and self._policy.config.use_amp
            else nullcontext()
        )
        with torch.inference_mode(), autocast_ctx:
            observation = prepare_observation_for_inference(
                observation, self._device, self._task, self._robot_type
            )
            observation = self._preprocessor(observation)
            action = self._policy.select_action(observation)
            action = self._postprocessor(action)
            action_tensor = action.squeeze(0).cpu()

        # Reorder to match dataset action ordering so the caller can treat
        # the returned tensor uniformly across backends.
        action_dict = make_robot_action(action_tensor, self._dataset_features)
        return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
@@ -0,0 +1,112 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Memory-bounded ring buffer for the Highlight Reel rollout strategy."""

from __future__ import annotations

from collections import deque

import numpy as np
import torch


class RolloutRingBuffer:
    """Fixed-capacity circular buffer for observation/action frames.

    Stores the last *N* seconds of telemetry in memory, bounded by both
    time (``max_frames``) and memory (``max_memory_bytes``). When either
    limit is reached the oldest frames are evicted.

    .. note::
        This class is **single-threaded**. ``append``/``drain``/``clear``
        must all be called from the same thread (the rollout main loop).
        Concurrent access from a background thread will corrupt
        ``_current_bytes`` accounting.

    Parameters
    ----------
    max_seconds:
        Maximum duration of buffered telemetry.
    max_memory_mb:
        Hard memory cap in MiB. Frames are evicted when the estimated
        total size exceeds this.
    fps:
        Frames per second — used to convert ``max_seconds`` to a frame
        count.
    """

    def __init__(self, max_seconds: float = 30.0, max_memory_mb: float = 2048.0, fps: float = 30.0) -> None:
        self._max_frames = int(max_seconds * fps)
        self._max_bytes = int(max_memory_mb * 1024 * 1024)
        self._buffer: deque[dict] = deque(maxlen=self._max_frames)
        self._current_bytes: int = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def append(self, frame: dict) -> None:
        """Add *frame* to the buffer, evicting the oldest if at capacity."""
        frame_bytes = _estimate_frame_bytes(frame)

        # Evict explicitly when the deque is full so ``_current_bytes`` stays
        # in sync: ``deque(maxlen=...)`` would otherwise drop the oldest frame
        # silently on append, without updating the byte accounting.
        while len(self._buffer) >= self._max_frames and self._buffer:
            evicted = self._buffer.popleft()
            self._current_bytes -= _estimate_frame_bytes(evicted)

        # Evict oldest frames until we are under the memory cap
        while self._current_bytes + frame_bytes > self._max_bytes and self._buffer:
            evicted = self._buffer.popleft()
            self._current_bytes -= _estimate_frame_bytes(evicted)

        self._buffer.append(frame)
        self._current_bytes += frame_bytes

    def drain(self) -> list[dict]:
        """Return all buffered frames and clear the buffer."""
        frames = list(self._buffer)
        self._buffer.clear()
        self._current_bytes = 0
        return frames

    def clear(self) -> None:
        """Discard all buffered frames."""
        self._buffer.clear()
        self._current_bytes = 0

    def __len__(self) -> int:
        return len(self._buffer)

    @property
    def estimated_bytes(self) -> int:
        """Estimated total byte size of all buffered frames."""
        return self._current_bytes


# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------


def _estimate_frame_bytes(frame: dict) -> int:
    """Rough byte estimate for a single frame dictionary."""
    total = 0
    for v in frame.values():
        if isinstance(v, torch.Tensor):
            # Compute the size explicitly rather than relying on
            # ``Tensor.nbytes`` so the memory cap is honoured across
            # PyTorch versions, even when frames hold unconverted tensors.
            total += v.nelement() * v.element_size()
        elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
            total += v.nbytes
        elif isinstance(v, (int, float)):
            total += 8
        elif isinstance(v, (str, bytes)):
            total += len(v)
    return max(total, 1)  # avoid zero-size frames
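The dual-cap eviction (frame count plus byte budget) can be exercised in isolation. The sketch below is a toy stand-in, not the lerobot class: `MiniRing` mirrors the same eviction rule with plain `bytes` payloads so it runs without torch or numpy, and the caps are deliberately tiny:

```python
from collections import deque


class MiniRing:
    """Toy mirror of RolloutRingBuffer's dual-cap eviction (frames + bytes)."""

    def __init__(self, max_frames: int, max_bytes: int) -> None:
        self._buf: deque[bytes] = deque()
        self._max_frames = max_frames
        self._max_bytes = max_bytes
        self._bytes = 0

    def append(self, frame: bytes) -> None:
        # Evict oldest frames while either cap would be violated by the new frame.
        while self._buf and (
            len(self._buf) >= self._max_frames or self._bytes + len(frame) > self._max_bytes
        ):
            self._bytes -= len(self._buf.popleft())
        self._buf.append(frame)
        self._bytes += len(frame)

    def drain(self) -> list[bytes]:
        frames = list(self._buf)
        self._buf.clear()
        self._bytes = 0
        return frames


ring = MiniRing(max_frames=3, max_bytes=10)
for payload in (b"aaaa", b"bbbb", b"cccc", b"dddd"):  # 4 bytes each
    ring.append(payload)
result = ring.drain()
print(result)  # → [b'cccc', b'dddd']: the 10-byte cap evicted the two oldest frames
```

Each 4-byte frame fits the frame cap, but a third frame would push the total to 12 bytes, so the byte cap dominates and only the two newest frames survive.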
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Thread-safe robot wrapper for concurrent observation/action access."""

from __future__ import annotations

from threading import Lock
from typing import Any

from lerobot.robots import Robot


class ThreadSafeRobot:
    """Lock-protected wrapper around a :class:`Robot` for use with background threads.

    When RTC inference runs in a background thread while the main loop
    executes actions, both threads may access the robot concurrently.
    This wrapper serialises ``get_observation`` and ``send_action`` calls.

    Read-only properties are proxied without the lock since they don't
    mutate hardware state.
    """

    def __init__(self, robot: Robot) -> None:
        self._robot = robot
        self._lock = Lock()

    # -- Lock-protected I/O --------------------------------------------------

    def get_observation(self) -> dict[str, Any]:
        with self._lock:
            return self._robot.get_observation()

    def send_action(self, action: dict[str, Any] | Any) -> Any:
        with self._lock:
            return self._robot.send_action(action)

    # -- Read-only proxies (no lock needed) -----------------------------------

    @property
    def observation_features(self) -> dict:
        return self._robot.observation_features

    @property
    def action_features(self) -> dict:
        return self._robot.action_features

    @property
    def name(self) -> str:
        return self._robot.name

    @property
    def robot_type(self) -> str:
        return self._robot.robot_type

    @property
    def cameras(self):
        return getattr(self._robot, "cameras", {})

    @property
    def is_connected(self) -> bool:
        return self._robot.is_connected

    @property
    def inner(self) -> Robot:
        """Access the underlying robot (e.g. for connect/disconnect)."""
        return self._robot
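The value of the lock can be demonstrated without hardware. In this sketch, `DummyRobot` is a hypothetical stand-in (not a lerobot class) that counts overlapping entries into its I/O method, and `LockedRobot` applies the same `with self._lock:` pattern as the wrapper above; four threads hammer it concurrently:

```python
import threading
import time


class DummyRobot:
    """Hypothetical stand-in for a hardware Robot; counts concurrent entries."""

    def __init__(self) -> None:
        self.overlaps = 0
        self._busy = False

    def get_observation(self) -> dict:
        if self._busy:  # another thread is mid-call: a real serial bus would corrupt here
            self.overlaps += 1
        self._busy = True
        time.sleep(0.0005)  # simulate serial-bus latency
        self._busy = False
        return {"pos": 0.0}


class LockedRobot:
    """Same locking pattern as ThreadSafeRobot above."""

    def __init__(self, robot: DummyRobot) -> None:
        self._robot = robot
        self._lock = threading.Lock()

    def get_observation(self) -> dict:
        with self._lock:
            return self._robot.get_observation()


robot = DummyRobot()
safe = LockedRobot(robot)
threads = [
    threading.Thread(target=lambda: [safe.get_observation() for _ in range(50)])
    for _ in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(robot.overlaps)  # → 0: the lock serialises every hardware access
```

Dropping the `with self._lock:` line and calling `robot.get_observation()` directly from the threads would make `overlaps` climb, which is exactly the interleaving the wrapper exists to prevent.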
@@ -0,0 +1,36 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rollout strategies — public API re-exports."""

from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy

__all__ = [
    "BaseStrategy",
    "DAggerEvents",
    "DAggerPhase",
    "DAggerStrategy",
    "HighlightStrategy",
    "RolloutStrategy",
    "SentryStrategy",
    "create_strategy",
    "estimate_max_episode_seconds",
    "safe_push_to_hub",
    "send_next_action",
]
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base rollout strategy: autonomous policy execution with no data recording."""

from __future__ import annotations

import logging
import time

from lerobot.utils.robot_utils import precise_sleep

from ..context import RolloutContext
from .core import RolloutStrategy, send_next_action

logger = logging.getLogger(__name__)


class BaseStrategy(RolloutStrategy):
    """Autonomous policy rollout with no data recording.

    All actions flow through the ``robot_action_processor`` pipeline
    before reaching the robot.
    """

    def setup(self, ctx: RolloutContext) -> None:
        """Initialise the inference engine."""
        self._init_engine(ctx)
        logger.info("Base strategy ready")

    def run(self, ctx: RolloutContext) -> None:
        """Run the autonomous control loop until shutdown or duration expires."""
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        interpolator = self._interpolator

        control_interval = interpolator.get_control_interval(cfg.fps)

        start_time = time.perf_counter()
        engine.resume()
        logger.info("Base strategy control loop started")

        while not ctx.runtime.shutdown_event.is_set():
            loop_start = time.perf_counter()

            if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
                logger.info("Duration limit reached (%.0fs)", cfg.duration)
                break

            obs = robot.get_observation()
            obs_processed = ctx.processors.robot_observation_processor(obs)
            engine.notify_observation(obs_processed)

            if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                continue

            action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
            self._log_telemetry(obs_processed, action_dict, ctx.runtime)

            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)

    def teardown(self, ctx: RolloutContext) -> None:
        """Disconnect hardware and stop inference."""
        self._teardown_hardware(ctx.hardware)
        logger.info("Base strategy teardown complete")
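The control loop above follows a standard fixed-rate pacing pattern: measure the work done this tick, then sleep the remainder of the interval. A minimal self-contained sketch (`run_paced_loop` is a hypothetical helper; `time.sleep` stands in for lerobot's `precise_sleep`):

```python
import time


def run_paced_loop(control_interval: float, ticks: int) -> float:
    """Minimal sketch of the fixed-rate pacing pattern used in BaseStrategy.run."""
    start = time.perf_counter()
    for _ in range(ticks):
        loop_start = time.perf_counter()
        # ... one control tick would go here (get observation, send action) ...
        dt = time.perf_counter() - loop_start
        if (sleep_t := control_interval - dt) > 0:
            time.sleep(sleep_t)  # lerobot uses precise_sleep; time.sleep suffices here
    return time.perf_counter() - start


elapsed = run_paced_loop(control_interval=0.01, ticks=20)
print(elapsed)  # roughly 0.2 s: 20 ticks at a 10 ms interval
```

Subtracting `dt` before sleeping is what keeps the tick rate stable: a slow observation read eats into the sleep budget instead of stretching the interval.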
@@ -0,0 +1,272 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rollout strategy ABC and shared action-dispatch helper."""

from __future__ import annotations

import abc
import logging
import time
from typing import TYPE_CHECKING

from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_rerun_data

from ..inference import InferenceEngine

if TYPE_CHECKING:
    from ..configs import RolloutStrategyConfig
    from ..context import HardwareContext, RolloutContext, RuntimeContext

logger = logging.getLogger(__name__)


class RolloutStrategy(abc.ABC):
    """Abstract base for rollout execution strategies.

    Each concrete strategy implements a self-contained control loop with
    its own recording/interaction semantics. Strategies are mutually
    exclusive — only one runs per session.
    """

    def __init__(self, config: RolloutStrategyConfig) -> None:
        self.config = config
        self._engine: InferenceEngine | None = None
        self._interpolator: ActionInterpolator | None = None
        self._warmup_flushed: bool = False

    def _init_engine(self, ctx: RolloutContext) -> None:
        """Attach the inference engine and action interpolator, then start the backend.

        Creates an :class:`ActionInterpolator` from the config's
        ``interpolation_multiplier`` and starts the inference engine.
        Call this from ``setup()`` so strategies share identical
        initialisation without duplicating code.
        """
        self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
        self._engine = ctx.policy.inference
        logger.info("Starting inference engine...")
        self._engine.start()
        self._warmup_flushed = False
        logger.info("Inference engine started")

    def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
        """Handle torch.compile warmup phase.

        Returns ``True`` if the caller should ``continue`` (still warming
        up). On the first post-warmup iteration the engine and
        interpolator are reset so stale warmup state is discarded.
        """
        engine = self._engine
        interpolator = self._interpolator
        if not use_torch_compile:
            return False
        if not engine.ready:
            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)
            return True
        if not self._warmup_flushed:
            logger.info("Warmup complete — flushing stale state and resuming engine")
            engine.reset()
            interpolator.reset()
            self._warmup_flushed = True
            engine.resume()
        return False

    def _teardown_hardware(self, hw: HardwareContext) -> None:
        """Stop the inference engine, return robot to initial position, and disconnect hardware."""
        if self._engine is not None:
            logger.info("Stopping inference engine...")
            self._engine.stop()
        robot = hw.robot_wrapper.inner
        if robot.is_connected:
            if hw.initial_position:
                logger.info("Returning robot to initial position before shutdown...")
                self._return_to_initial_position(hw)
            logger.info("Disconnecting robot...")
            robot.disconnect()
        teleop = hw.teleop
        if teleop is not None and teleop.is_connected:
            logger.info("Disconnecting teleoperator...")
            teleop.disconnect()

    @staticmethod
    def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None:
        """Smoothly interpolate the robot back to its initial position."""
        robot = hw.robot_wrapper
        target = hw.initial_position
        try:
            current_obs = robot.get_observation()
            current_pos = {k: v for k, v in current_obs.items() if k in target}
            steps = max(int(duration_s * fps), 1)
            for step in range(1, steps + 1):
                t = step / steps
                interp = {}
                for k in current_pos:
                    interp[k] = current_pos[k] * (1 - t) + target[k] * t
                robot.send_action(interp)
                precise_sleep(1 / fps)
        except Exception as e:
            logger.warning("Could not return to initial position: %s", e)

    @staticmethod
    def _log_telemetry(
        obs_processed: dict | None,
        action_dict: dict | None,
        runtime_ctx: RuntimeContext,
    ) -> None:
        """Log observation/action telemetry to Rerun if display_data is enabled."""
        cfg = runtime_ctx.cfg
        if not cfg.display_data:
            return
        log_rerun_data(
            observation=obs_processed,
            action=action_dict,
            compress_images=cfg.display_compressed_images,
        )

    @abc.abstractmethod
    def setup(self, ctx: RolloutContext) -> None:
        """Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""

    @abc.abstractmethod
    def run(self, ctx: RolloutContext) -> None:
        """Main rollout loop. Returns when shutdown is requested or duration expires."""

    @abc.abstractmethod
    def teardown(self, ctx: RolloutContext) -> None:
        """Cleanup: save dataset, stop threads, disconnect hardware."""


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def safe_push_to_hub(dataset, tags=None, private=False) -> bool:
    """Push dataset to hub, skipping if no episodes have been saved.

    Returns ``True`` if the push was attempted, ``False`` if skipped.
    """
    if dataset.num_episodes == 0:
        logger.warning("No episodes saved — skipping push to hub")
        return False
    dataset.push_to_hub(tags=tags, private=private)
    return True


def estimate_max_episode_seconds(
    dataset_features: dict,
    fps: float,
    target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
) -> float:
    """Conservatively estimate how many seconds of video will exceed *target_size_mb*.

    Each camera produces its own video file, so the episode duration is
    driven by the **slowest** camera to fill ``target_size_mb`` — i.e.
    the one with the fewest pixels per frame (lowest bitrate).

    Uses a deliberately **low** bits-per-pixel estimate so the computed
    duration is *longer* than reality. By the time the timer fires the
    actual video file is guaranteed to have crossed the target size,
    which aligns episode boundaries with the dataset's video-file
    chunking — each ``push_to_hub`` uploads complete files rather than
    re-uploading a still-growing one.

    The estimate ignores codec-specific settings (CRF, preset) on purpose:
    we only need a rough lower bound on bitrate, not a precise prediction.

    Falls back to 600 s (10 min) when no video features are present.
    """
    # 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
    # robot footage (real-world is typically 0.1 – 0.3 bpp). Under-
    # estimating the bitrate over-estimates the time → the episode will be
    # *larger* than target_size_mb when we save, which is what we want.
    conservative_bpp = 0.1

    # Collect per-camera pixel counts — each camera has its own video file.
    camera_pixels = []
    for feat in dataset_features.values():
        if feat.get("dtype") == "video":
            shape = feat.get("shape", ())
            # The shape may be (C, H, W) or (T, C, H, W); take the spatial
            # dimensions from the end.
            if len(shape) >= 3:
                h, w = shape[-2], shape[-1]
                pixels = h * w
                if pixels > 0:
                    camera_pixels.append(pixels)

    if not camera_pixels:
        return 600.0

    # Use the smallest camera: it produces the lowest bitrate and therefore
    # takes the longest to reach the target — the conservative choice.
    min_pixels = min(camera_pixels)
    bits_per_frame = min_pixels * conservative_bpp
    bytes_per_second = (bits_per_frame * fps) / 8

    # Guard against division by zero just in case
    if bytes_per_second <= 0:
        return 600.0

    return (target_size_mb * 1024 * 1024) / bytes_per_second
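A quick numeric check of the estimate helps build intuition. The sketch below reduces the same arithmetic to a single camera (`estimate_seconds` is a hypothetical helper, not part of the codebase; the resolution, fps, and 100 MB target are illustrative):

```python
def estimate_seconds(height: int, width: int, fps: float, target_size_mb: float, bpp: float = 0.1) -> float:
    # Same arithmetic as estimate_max_episode_seconds, reduced to one camera.
    bits_per_frame = height * width * bpp
    bytes_per_second = (bits_per_frame * fps) / 8
    return (target_size_mb * 1024 * 1024) / bytes_per_second


# A 480x640 camera at 30 fps yields ~115 kB/s at 0.1 bpp, so a 100 MB
# target is reached after roughly 910 s (about 15 minutes).
print(round(estimate_seconds(480, 640, 30.0, 100.0)))  # → 910
```

Doubling the assumed bpp would halve the estimate, which illustrates why the deliberately low 0.1 bpp makes the timer fire late rather than early.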


# ---------------------------------------------------------------------------
# Shared action-dispatch helper
# ---------------------------------------------------------------------------


def send_next_action(
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
) -> dict | None:
    """Dispatch the next action to the robot.

    Pulls the next action tensor from the inference engine, feeds the
    interpolator, and sends the interpolated action through the
    ``robot_action_processor`` to the robot. Works identically for
    sync and async backends — the rollout strategy never needs to branch.

    Returns the action dict that was sent, or ``None`` if no action was
    ready (e.g. empty async queue, interpolator not yet primed).
    """
    engine = ctx.policy.inference
    features = ctx.data.dataset_features
    ordered_keys = ctx.data.ordered_action_keys

    if interpolator.needs_new_action():
        obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
        action_tensor = engine.get_action(obs_frame)
        if action_tensor is not None:
            interpolator.add(action_tensor.cpu())

    interp = interpolator.get()
    if interp is None:
        return None

    action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
    processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
    ctx.hardware.robot_wrapper.send_action(processed)
    return action_dict
@@ -0,0 +1,733 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DAgger rollout strategy: Human-in-the-Loop data collection.

Implements the RaC paradigm (Recovery and Correction) for interactive
imitation learning. Alternates between autonomous policy execution and
human intervention via teleoperator.

Input is controlled via either a keyboard or foot pedal, selected by
the ``input_device`` config field. Each device exposes three actions:

1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED).
2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING).
3. **upload** — Push dataset to hub on demand (corrections-only mode).
ESC (keyboard only) — Stop session.

Recording Modes:
    ``record_autonomous=True``: Sentry-like continuous recording with
        time-based episode rotation. Both autonomous and correction
        frames are recorded; corrections tagged ``intervention=True``.
    ``record_autonomous=False``: Only correction windows are recorded.
        Each correction (start to stop) becomes one episode.
"""

from __future__ import annotations

import contextlib
import enum
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from typing import Any

import numpy as np

from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.pedal import start_pedal_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say

from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action

PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
    try:
        if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
            logging.info("No DISPLAY set. Skipping pynput import.")
            PYNPUT_AVAILABLE = False
        else:
            from pynput import keyboard
    except Exception as e:
        PYNPUT_AVAILABLE = False
        logging.info(f"Could not import pynput: {e}")

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# DAgger state machine
# ---------------------------------------------------------------------------


class DAggerPhase(enum.Enum):
    """Observable phases of a DAgger episode."""

    AUTONOMOUS = "autonomous"  # Policy driving
    PAUSED = "paused"  # Engine paused, teleop aligned, awaiting input
    CORRECTING = "correcting"  # Human driving via teleop, recording interventions


# Valid (current_phase, event) -> next_phase
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
    (DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
    (DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
    (DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
    (DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
}

class DAggerEvents:
    """Thread-safe container for DAgger input device events.

    The keyboard/pedal threads write transition requests; the main loop
    consumes them.
    """

    def __init__(self) -> None:
        self._lock = Lock()
        self._phase = DAggerPhase.AUTONOMOUS
        self._pending_transition: str | None = None

        # Session-level flags
        self.stop_recording = Event()
        self.upload_requested = Event()

    # -- Thread-safe phase access ------------------------------------------

    @property
    def phase(self) -> DAggerPhase:
        """Current phase of the DAgger state machine."""
        with self._lock:
            return self._phase

    @phase.setter
    def phase(self, value: DAggerPhase) -> None:
        with self._lock:
            self._phase = value

    def request_transition(self, event: str) -> None:
        """Request a phase transition (called from keyboard/pedal threads).

        Only enqueues the request if it corresponds to a valid transition
        from the current phase, preventing impossible state changes.
        """
        with self._lock:
            if (self._phase, event) in _DAGGER_TRANSITIONS:
                self._pending_transition = event

    def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
        """Consume a pending transition (called from the main loop)."""
        with self._lock:
            if self._pending_transition is None:
                return None
            key = (self._phase, self._pending_transition)
            self._pending_transition = None
            new_phase = _DAGGER_TRANSITIONS.get(key)
            if new_phase is None:
                return None
            old_phase = self._phase
            self._phase = new_phase
            return old_phase, new_phase

    def reset(self) -> None:
        """Reset all transient state for a fresh session."""
        with self._lock:
            self._phase = DAggerPhase.AUTONOMOUS
            self._pending_transition = None
            self.upload_requested.clear()


# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------


# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
def _teleop_smooth_move_to(
    teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
    """Smoothly move teleop to target position via linear interpolation.

    Requires the teleoperator to support motor control methods
    (``enable_torque``, ``write_goal_positions``, ``get_action``).
    """
    teleop.enable_torque()
    current = teleop.get_action()
    steps = max(int(duration_s * fps), 1)

    for step in range(steps + 1):
        t = step / steps
        interp = {}
        for k in current:
            if k in target_pos:
                interp[k] = current[k] * (1 - t) + target_pos[k] * t
            else:
                interp[k] = current[k]
        teleop.write_goal_positions(interp)
        time.sleep(1 / fps)


# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------

def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
    """Initialise keyboard listener with DAgger 3-key controls.

    Returns the pynput Listener (or ``None`` in headless mode or when
    pynput is unavailable).
    """
    if not PYNPUT_AVAILABLE or is_headless():
        logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
        return None

    # Map config key names to pynput Key objects for special keys
    special_keys = {
        "space": keyboard.Key.space,
        "tab": keyboard.Key.tab,
        "enter": keyboard.Key.enter,
    }

    def _resolve_key(key) -> str | None:
        """Resolve a pynput key event to a config-comparable string."""
        if key == keyboard.Key.esc:
            return "esc"
        for name, pynput_key in special_keys.items():
            if key == pynput_key:
                return name
        if hasattr(key, "char") and key.char:
            return key.char
        return None

    # Build mapping: resolved key string -> DAgger event name
    key_to_event = {
        cfg.pause_resume: "pause_resume",
        cfg.correction: "correction",
    }

    def on_press(key):
        try:
            resolved = _resolve_key(key)
            if resolved is None:
                return
            if resolved == "esc":
                logger.info("Stop recording...")
                events.stop_recording.set()
                return
            if resolved in key_to_event:
                events.request_transition(key_to_event[resolved])
            if resolved == cfg.upload:
                events.upload_requested.set()
        except Exception as e:
            logger.debug("Key error: %s", e)

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    logger.info(
        "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
        cfg.pause_resume,
        cfg.correction,
        cfg.upload,
    )
    return listener


def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
    """Initialise foot pedal listener with DAgger 3-pedal controls.

    Returns the pedal listener thread (or ``None`` if evdev is unavailable).
    """
    code_to_event = {
        cfg.pause_resume: "pause_resume",
        cfg.correction: "correction",
    }

    def on_press(code: str) -> None:
        if code in code_to_event:
            events.request_transition(code_to_event[code])
        if code == cfg.upload:
            events.upload_requested.set()

    logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
    return start_pedal_listener(on_press, device_path=cfg.device_path)


# ---------------------------------------------------------------------------
# DAgger Strategy
# ---------------------------------------------------------------------------

class DAggerStrategy(RolloutStrategy):
    """Human-in-the-Loop data collection with intervention tagging.

    State machine::

        AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
                                      --(key1)--> AUTONOMOUS

    Recording modes:
        ``record_autonomous=True``: Sentry-like continuous recording with
            time-based episode rotation. Intervention frames tagged True.
        ``record_autonomous=False``: Only correction windows recorded.
            Each correction = one episode. Upload on demand via key3.
    """

    config: DAggerStrategyConfig

    def __init__(self, config: DAggerStrategyConfig):
        super().__init__(config)
        self._listener = None
        self._pedal_thread = None
        self._events = DAggerEvents()
        self._push_executor: ThreadPoolExecutor | None = None
        self._pending_push: Future | None = None
        self._needs_push = Event()
        self._episode_lock = Lock()

    def setup(self, ctx: RolloutContext) -> None:
        """Initialise the inference engine and input device listener."""
        self._init_engine(ctx)
        self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
        target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
        self._episode_duration_s = estimate_max_episode_seconds(
            ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
        )

        if self.config.input_device == "keyboard":
            self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
        else:
            self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)

        record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
        logger.info(
            "DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)",
            self.config.input_device,
            self.config.num_episodes,
            record_mode,
            self._episode_duration_s,
        )

    def run(self, ctx: RolloutContext) -> None:
        """Run DAgger episodes with human-in-the-loop intervention."""
        if self.config.record_autonomous:
            self._run_continuous(ctx)
        else:
            self._run_corrections_only(ctx)

    def teardown(self, ctx: RolloutContext) -> None:
        """Stop listeners, finalise the dataset, and disconnect hardware."""
        play_sounds = ctx.runtime.cfg.play_sounds
        logger.info("Stopping DAgger recording")
        log_say("Stopping DAgger recording", play_sounds)

        if self._listener is not None and not is_headless():
            logger.info("Stopping keyboard listener")
            self._listener.stop()

        # Flush any queued/running push cleanly
        if self._push_executor is not None:
            logger.info("Shutting down push executor (waiting for pending pushes)...")
            self._push_executor.shutdown(wait=True)
            self._push_executor = None

        if ctx.data.dataset is not None:
            logger.info("Finalizing dataset...")
            ctx.data.dataset.finalize()
            if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
                logger.info("Pushing final dataset to hub...")
                if safe_push_to_hub(
                    ctx.data.dataset,
                    tags=ctx.runtime.cfg.dataset.tags,
                    private=ctx.runtime.cfg.dataset.private,
                ):
                    logger.info("Dataset uploaded to hub")
                    log_say("Dataset uploaded to hub", play_sounds)

        self._teardown_hardware(ctx.hardware)
        logger.info("DAgger strategy teardown complete")

    # ------------------------------------------------------------------
    # Continuous recording mode (record_autonomous=True)
    # ------------------------------------------------------------------

    def _run_continuous(self, ctx: RolloutContext) -> None:
        """Sentry-like continuous recording with intervention tagging.

        Episodes are auto-rotated every ``episode_time_s`` seconds and
        uploaded in the background every ``upload_every_n_episodes`` episodes.
        Both autonomous and correction frames are recorded; corrections are
        tagged with ``intervention=True``.
        """
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        teleop = ctx.hardware.teleop
        dataset = ctx.data.dataset
        events = self._events
        interpolator = self._interpolator
        features = ctx.data.dataset_features

        control_interval = interpolator.get_control_interval(cfg.fps)
        record_stride = max(1, cfg.interpolation_multiplier)
        task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
        play_sounds = cfg.play_sounds

        engine.reset()
        interpolator.reset()
        events.reset()
        # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
        # user is responsible for moving the teleop to the same position as the robot when starting the correction.
        # teleop.disable_torque()
        engine.resume()

        last_action: dict[str, Any] | None = None
        record_tick = 0
        start_time = time.perf_counter()
        episode_start = time.perf_counter()
        episodes_since_push = 0
        episode_duration_s = self._episode_duration_s
        logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)

        with VideoEncodingManager(dataset):
            try:
                while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
                    loop_start = time.perf_counter()

                    if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
                        logger.info("Duration limit reached (%.0fs)", cfg.duration)
                        break

                    # Process transitions
                    transition = events.consume_transition()
                    if transition is not None:
                        old_phase, new_phase = transition
                        self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
                        last_action = None

                    phase = events.phase
                    obs = robot.get_observation()
                    obs_processed = ctx.processors.robot_observation_processor(obs)
                    obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)

                    # --- CORRECTING: human teleop control ---
                    if phase == DAggerPhase.CORRECTING:
                        teleop_action = teleop.get_action()
                        processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
                        robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
                        robot.send_action(robot_action_to_send)
                        last_action = robot_action_to_send
                        self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
                        action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
                        if record_tick % record_stride == 0:
                            frame = {
                                **obs_frame,
                                **action_frame,
                                "task": task_str,
                                "intervention": np.array([True], dtype=bool),
                            }
                            dataset.add_frame(frame)
                        record_tick += 1

                    # --- PAUSED: hold position ---
                    elif phase == DAggerPhase.PAUSED:
                        if last_action:
                            robot.send_action(last_action)

                    # --- AUTONOMOUS: policy control ---
                    else:
                        engine.notify_observation(obs_processed)

                        if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                            continue

                        action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
                        if action_dict is not None:
                            self._log_telemetry(obs_processed, action_dict, ctx.runtime)
                            last_action = ctx.processors.robot_action_processor((action_dict, obs))
                            action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
                            if record_tick % record_stride == 0:
                                frame = {
                                    **obs_frame,
                                    **action_frame,
                                    "task": task_str,
                                    "intervention": np.array([False], dtype=bool),
                                }
                                dataset.add_frame(frame)
                            record_tick += 1

                    # Episode rotation derived from video file-size target.
                    # Do NOT save mid-correction — wait for the correction
                    # to finish so the episode boundary is clean.
                    elapsed = time.perf_counter() - episode_start
                    if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
                        with self._episode_lock:
                            dataset.save_episode()
                        episodes_since_push += 1
                        self._needs_push.set()
                        logger.info(
                            "Episode saved (total: %d, elapsed: %.1fs)",
                            dataset.num_episodes,
                            elapsed,
                        )
                        log_say(f"Episode {dataset.num_episodes} saved", play_sounds)

                        if episodes_since_push >= self.config.upload_every_n_episodes:
                            self._background_push(dataset, cfg)
                            episodes_since_push = 0

                        episode_start = time.perf_counter()

                    dt = time.perf_counter() - loop_start
                    if (sleep_t := control_interval - dt) > 0:
                        precise_sleep(sleep_t)

            finally:
                logger.info("DAgger continuous control loop ended — pausing engine")
                engine.pause()
                # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
                # user is responsible for moving the teleop to the same position as the robot when starting the correction.
                # teleop.disable_torque()
                with contextlib.suppress(Exception):
                    with self._episode_lock:
                        dataset.save_episode()
                    self._needs_push.set()
                    logger.info("Final in-progress episode saved")

    # ------------------------------------------------------------------
    # Corrections-only mode (record_autonomous=False)
    # ------------------------------------------------------------------

    def _run_corrections_only(self, ctx: RolloutContext) -> None:
        """Record only human correction windows. Each correction = one episode.

        The policy runs autonomously without recording. When the user
        pauses and starts a correction, frames are recorded with
        ``intervention=True``. Stopping the correction saves the episode.
        The dataset can be uploaded on demand via the upload key/pedal.
        """
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        teleop = ctx.hardware.teleop
        dataset = ctx.data.dataset
        events = self._events
        interpolator = self._interpolator
        features = ctx.data.dataset_features

        control_interval = interpolator.get_control_interval(cfg.fps)
        record_stride = max(1, cfg.interpolation_multiplier)
        task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
        play_sounds = cfg.play_sounds

        engine.reset()
        interpolator.reset()
        events.reset()
        # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
        # user is responsible for moving the teleop to the same position as the robot when starting the correction.
        # teleop.disable_torque()
        engine.resume()

        last_action: dict[str, Any] | None = None
        record_tick = 0
        recorded = 0
        logger.info(
            "DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes
        )

        with VideoEncodingManager(dataset):
            try:
                while (
                    recorded < self.config.num_episodes
                    and not events.stop_recording.is_set()
                    and not ctx.runtime.shutdown_event.is_set()
                ):
                    loop_start = time.perf_counter()

                    # Process transitions
                    transition = events.consume_transition()
                    if transition is not None:
                        old_phase, new_phase = transition
                        self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
                        last_action = None

                        # Correction ended -> save episode (blocking if not streaming)
                        if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
                            with self._episode_lock:
                                dataset.save_episode()
                            recorded += 1
                            self._needs_push.set()
                            logger.info(
                                "Correction %d/%d saved",
                                recorded,
                                self.config.num_episodes,
                            )
                            log_say(f"Correction {recorded} saved", play_sounds)

                    # On-demand upload
                    if events.upload_requested.is_set():
                        events.upload_requested.clear()
                        logger.info("Upload requested by user")
                        self._background_push(dataset, cfg)

                    phase = events.phase
                    obs = robot.get_observation()
                    obs_processed = ctx.processors.robot_observation_processor(obs)

                    # --- CORRECTING: human teleop control + recording ---
                    if phase == DAggerPhase.CORRECTING:
                        teleop_action = teleop.get_action()
                        processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
                        robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
                        robot.send_action(robot_action_to_send)
                        last_action = robot_action_to_send
                        self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)

                        obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
                        action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
                        if record_tick % record_stride == 0:
                            dataset.add_frame(
                                {
                                    **obs_frame,
                                    **action_frame,
                                    "task": task_str,
                                    "intervention": np.array([True], dtype=bool),
                                }
                            )
                        record_tick += 1

                    # --- PAUSED: hold position ---
                    elif phase == DAggerPhase.PAUSED:
                        if last_action:
                            robot.send_action(last_action)

                    # --- AUTONOMOUS: policy control (no recording) ---
                    else:
                        engine.notify_observation(obs_processed)

                        if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                            continue

                        action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
                        if action_dict is not None:
                            self._log_telemetry(obs_processed, action_dict, ctx.runtime)
                            last_action = ctx.processors.robot_action_processor((action_dict, obs))

                    dt = time.perf_counter() - loop_start
                    if (sleep_t := control_interval - dt) > 0:
                        precise_sleep(sleep_t)

            finally:
                logger.info("DAgger corrections-only loop ended — pausing engine")
                engine.pause()
                # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
                # user is responsible for moving the teleop to the same position as the robot when starting the correction.
                # teleop.disable_torque()
                with contextlib.suppress(Exception):
                    with self._episode_lock:
                        dataset.save_episode()
                    self._needs_push.set()
                    logger.info("Final in-progress episode saved")

    # ------------------------------------------------------------------
    # State-machine transition side-effects
    # ------------------------------------------------------------------

    @staticmethod
    def _apply_transition(
        old_phase: DAggerPhase,
        new_phase: DAggerPhase,
        engine,
        interpolator,
        robot: ThreadSafeRobot,
        teleop: Teleoperator,
    ) -> None:
        """Execute side-effects for a validated phase transition."""
        logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
        if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
            logger.info("Pausing engine — robot holds position")
            engine.pause()
            obs = robot.get_observation()
            _robot_pos = {
                k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
            }
            # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
            # user is responsible for moving the teleop to the same position as the robot when starting the correction.
            # _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)

        elif new_phase == DAggerPhase.CORRECTING:
            logger.info("Entering correction mode — human teleop control")
            # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
            # user is responsible for moving the teleop to the same position as the robot when starting the correction.
            # teleop.disable_torque()

        elif new_phase == DAggerPhase.AUTONOMOUS:
            logger.info("Resuming autonomous mode — resetting engine and interpolator")
            interpolator.reset()
            engine.reset()
            engine.resume()

    # ------------------------------------------------------------------
    # Background push (shared by both modes)
    # ------------------------------------------------------------------

    def _background_push(self, dataset, cfg) -> None:
        """Queue a Hub push on the single-worker executor.

        The executor's ``max_workers=1`` guarantees at most one push runs at
        a time; submitted tasks are queued rather than dropped. Pushes
        are blocked while the operator is mid-correction to avoid
        uploading a partially-recorded episode.
        """
        if self._push_executor is None:
            return

        if self._events.phase == DAggerPhase.CORRECTING:
            logger.info("Skipping push — correction in progress")
            return

        if self._pending_push is not None and not self._pending_push.done():
            logger.info("Previous push still in progress; queueing next")

        def _push():
            try:
                with self._episode_lock:
                    if safe_push_to_hub(
                        dataset,
                        tags=cfg.dataset.tags if cfg.dataset else None,
                        private=cfg.dataset.private if cfg.dataset else False,
                    ):
                        self._needs_push.clear()
                        logger.info("Background push to hub complete")
            except Exception as e:
                logger.error("Background push failed: %s", e)

        self._pending_push = self._push_executor.submit(_push)
        logger.info("Background push task submitted")
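The DAgger phase logic above is a table-driven state machine: any `(phase, event)` pair missing from the transition dict is simply ignored. A self-contained sketch of that pattern (same table shape as `_DAGGER_TRANSITIONS`; `Phase` and `step` are illustrative names, not part of the package):

```python
from enum import Enum


class Phase(Enum):
    AUTONOMOUS = "autonomous"
    PAUSED = "paused"
    CORRECTING = "correcting"


# Valid (current_phase, event) -> next_phase, mirroring _DAGGER_TRANSITIONS.
TRANSITIONS = {
    (Phase.AUTONOMOUS, "pause_resume"): Phase.PAUSED,
    (Phase.PAUSED, "pause_resume"): Phase.AUTONOMOUS,
    (Phase.PAUSED, "correction"): Phase.CORRECTING,
    (Phase.CORRECTING, "correction"): Phase.PAUSED,
}


def step(phase: Phase, event: str) -> Phase:
    # Invalid events leave the phase unchanged, like request_transition's guard.
    return TRANSITIONS.get((phase, event), phase)
```

Because validation happens at enqueue time (under a lock in `DAggerEvents`), a stale key press can never drive the robot into an impossible state such as AUTONOMOUS directly to CORRECTING.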
@@ -0,0 +1,45 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Strategy factory: config type-name → strategy class dispatch."""

from __future__ import annotations

from typing import TYPE_CHECKING

from .base import BaseStrategy
from .core import RolloutStrategy
from .dagger import DAggerStrategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy

if TYPE_CHECKING:
    from lerobot.rollout import RolloutStrategyConfig


def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
    """Instantiate the appropriate strategy from a config object.

    Dispatches on ``config.type`` (the name registered via
    ``draccus.ChoiceRegistry``).
    """
    if config.type == "base":
        return BaseStrategy(config)
    if config.type == "sentry":
        return SentryStrategy(config)
    if config.type == "highlight":
        return HighlightStrategy(config)
    if config.type == "dagger":
        return DAggerStrategy(config)
    raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
@@ -0,0 +1,277 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Highlight Reel strategy: on-demand recording via ring buffer."""

from __future__ import annotations

import contextlib
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent

from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say

from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, safe_push_to_hub, send_next_action

PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
    try:
        if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
            logging.info("No DISPLAY set. Skipping pynput import.")
            PYNPUT_AVAILABLE = False
        else:
            from pynput import keyboard
    except Exception as e:
        PYNPUT_AVAILABLE = False
        logging.info(f"Could not import pynput: {e}")

logger = logging.getLogger(__name__)


class HighlightStrategy(RolloutStrategy):
    """Autonomous rollout with on-demand recording via ring buffer.

    The robot runs autonomously while a memory-bounded ring buffer
    captures continuous telemetry. When the user presses the save key:

    1. The ring buffer is flushed to the dataset (last *Z* seconds).
    2. Live recording continues until the save key is pressed again.
    3. The episode is saved and the ring buffer resumes capturing.

    Requires ``streaming_encoding=True`` (enforced in config validation)
    so that ``dataset.add_frame`` is a non-blocking queue put — draining
    900 frames stays sub-ms per frame.
    """

    config: HighlightStrategyConfig

    def __init__(self, config: HighlightStrategyConfig):
        super().__init__(config)
        require_package("pynput", extra="pynput-dep")
        self._ring: RolloutRingBuffer | None = None
        self._listener = None
        self._save_requested = ThreadingEvent()
        self._recording_live = ThreadingEvent()
        self._push_requested = ThreadingEvent()
        self._push_executor: ThreadPoolExecutor | None = None
        self._pending_push: Future | None = None

    def setup(self, ctx: RolloutContext) -> None:
        """Initialise the inference engine, ring buffer, and keyboard listener."""
        self._init_engine(ctx)

        self._ring = RolloutRingBuffer(
            max_seconds=self.config.ring_buffer_seconds,
            max_memory_mb=self.config.ring_buffer_max_memory_mb,
            fps=ctx.runtime.cfg.fps,
        )

        self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
        logger.info(
            "Ring buffer initialized (max_seconds=%.0f, max_memory=%.0fMB)",
            self.config.ring_buffer_seconds,
            self.config.ring_buffer_max_memory_mb,
        )
        self._setup_keyboard(ctx.runtime.shutdown_event)
        logger.info(
            "Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
            self.config.ring_buffer_seconds,
            self.config.save_key,
            self.config.push_key,
        )

    def run(self, ctx: RolloutContext) -> None:
        """Run the autonomous loop, buffering frames and recording on demand."""
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        dataset = ctx.data.dataset
        ring = self._ring
        interpolator = self._interpolator
        features = ctx.data.dataset_features

        control_interval = interpolator.get_control_interval(cfg.fps)

        engine.resume()
        play_sounds = cfg.play_sounds

        start_time = time.perf_counter()
        task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
        logger.info("Highlight strategy recording started (press '%s' to save)", self.config.save_key)

        with VideoEncodingManager(dataset):
            try:
                while not ctx.runtime.shutdown_event.is_set():
                    loop_start = time.perf_counter()

                    if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
                        logger.info("Duration limit reached (%.0fs)", cfg.duration)
                        break

                    obs = robot.get_observation()
                    obs_processed = ctx.processors.robot_observation_processor(obs)
                    engine.notify_observation(obs_processed)

                    if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                        continue

                    action_dict = send_next_action(obs_processed, obs, ctx, interpolator)

                    if action_dict is not None:
                        self._log_telemetry(obs_processed, action_dict, ctx.runtime)
                        obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
                        action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
                        frame = {**obs_frame, **action_frame, "task": task_str}

                        # NOTE: ``is_set()`` then ``clear()`` is not atomic
                        # against the keyboard thread setting the flag again
                        # in between — but that is benign: we lose at most one
                        # toggle, processed on the next iteration. The
                        # ``_recording_live`` branch below is reached in the
                        # SAME iteration after ``clear()`` runs, so a frame
                        # finalised by ``save_episode()`` is never re-added to
                        # the next episode.
                        if self._save_requested.is_set():
                            self._save_requested.clear()
                            if not self._recording_live.is_set():
                                logger.info(
                                    "Flushing ring buffer (%d frames) + starting live recording",
                                    len(ring),
                                )
                                for buffered_frame in ring.drain():
                                    dataset.add_frame(buffered_frame)
                                self._recording_live.set()
                            else:
                                dataset.add_frame(frame)
                                dataset.save_episode()
                                logger.info("Episode saved (total: %d)", dataset.num_episodes)
                                log_say(
                                    f"Episode {dataset.num_episodes} saved",
                                    play_sounds,
                                )
                                self._recording_live.clear()

                        if self._push_requested.is_set():
                            self._push_requested.clear()
                            logger.info("Push requested by user")
                            self._background_push(dataset, cfg)

                        if self._recording_live.is_set():
                            dataset.add_frame(frame)
                        else:
                            ring.append(frame)

                    dt = time.perf_counter() - loop_start
                    if (sleep_t := control_interval - dt) > 0:
                        precise_sleep(sleep_t)

            finally:
                logger.info("Highlight control loop ended")
                if self._recording_live.is_set():
                    logger.info("Saving in-progress live episode")
                    with contextlib.suppress(Exception):
                        dataset.save_episode()

    def teardown(self, ctx: RolloutContext) -> None:
        """Stop listeners, finalise the dataset, and disconnect hardware."""
        play_sounds = ctx.runtime.cfg.play_sounds
        logger.info("Stopping highlight recording")
        log_say("Stopping highlight recording", play_sounds)

        if self._listener is not None:
            logger.info("Stopping keyboard listener")
            self._listener.stop()

        if self._push_executor is not None:
            logger.info("Shutting down push executor (waiting for pending pushes)...")
            self._push_executor.shutdown(wait=True)
            self._push_executor = None

        if ctx.data.dataset is not None:
            logger.info("Finalizing dataset...")
            ctx.data.dataset.finalize()
            if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
                logger.info("Pushing final dataset to hub...")
                if safe_push_to_hub(
                    ctx.data.dataset,
                    tags=ctx.runtime.cfg.dataset.tags,
                    private=ctx.runtime.cfg.dataset.private,
                ):
                    logger.info("Dataset uploaded to hub")
                    log_say("Dataset uploaded to hub", play_sounds)

        self._teardown_hardware(ctx.hardware)
        logger.info("Highlight strategy teardown complete")

    def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
        """Set up keyboard listener for save and push keys."""
        if is_headless():
            logger.warning("Headless environment — highlight keys unavailable")
            return

        try:
            save_key = self.config.save_key
            push_key = self.config.push_key

            def on_press(key):
                with contextlib.suppress(Exception):
                    if hasattr(key, "char") and key.char == save_key:
                        self._save_requested.set()
                    elif hasattr(key, "char") and key.char == push_key:
                        self._push_requested.set()
                    elif key == keyboard.Key.esc:
                        self._save_requested.clear()
                        shutdown_event.set()

            self._listener = keyboard.Listener(on_press=on_press)
            self._listener.start()
            logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
        except ImportError:
            logger.warning("pynput not available — keyboard listener disabled")

    def _background_push(self, dataset, cfg) -> None:
        """Queue a Hub push on the single-worker executor."""
        if self._push_executor is None:
            return

        if self._pending_push is not None and not self._pending_push.done():
            logger.info("Previous push still in progress; queueing next")

        def _push():
            try:
                if safe_push_to_hub(
                    dataset,
                    tags=cfg.dataset.tags if cfg.dataset else None,
                    private=cfg.dataset.private if cfg.dataset else False,
                ):
                    logger.info("Background push to hub complete")
            except Exception as e:
                logger.error("Background push failed: %s", e)

        self._pending_push = self._push_executor.submit(_push)
        logger.info("Background push task submitted")
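The buffer-then-flush flow above (continuously append, flush the last N seconds on a keypress, then switch to live recording) can be sketched with a plain `collections.deque`. The `RingBuffer` class and its `drain` method below are illustrative stand-ins, not the repo's actual `RolloutRingBuffer` API:

```python
from collections import deque


class RingBuffer:
    """Fixed-capacity frame buffer: keeps only the last `max_seconds` of frames."""

    def __init__(self, max_seconds: float, fps: int):
        self._buf = deque(maxlen=int(max_seconds * fps))

    def append(self, frame):
        # When full, deque silently evicts the oldest frame: O(1), bounded memory.
        self._buf.append(frame)

    def drain(self):
        # Hand frames out oldest-first and leave the buffer empty for reuse.
        frames, self._buf = list(self._buf), deque(maxlen=self._buf.maxlen)
        return frames

    def __len__(self):
        return len(self._buf)


ring = RingBuffer(max_seconds=2.0, fps=3)  # capacity: 6 frames
for i in range(10):
    ring.append(i)
assert ring.drain() == [4, 5, 6, 7, 8, 9]  # only the last 6 frames survive
assert len(ring) == 0  # ready to keep buffering after the flush
```

The `deque(maxlen=...)` eviction is what keeps memory bounded no matter how long the robot runs between keypresses.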
@@ -0,0 +1,225 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sentry rollout strategy: continuous autonomous recording with auto-upload."""

from __future__ import annotations

import contextlib
import logging
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock

from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say

from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action

logger = logging.getLogger(__name__)


class SentryStrategy(RolloutStrategy):
    """Continuous autonomous rollout with always-on recording.

    Episode duration is derived from camera resolution, FPS, and
    ``DEFAULT_VIDEO_FILE_SIZE_IN_MB`` so that each saved episode
    produces a video file that has crossed the chunk-size boundary.
    This keeps ``push_to_hub`` efficient — it uploads complete video
    files rather than re-uploading a still-growing one.

    The dataset is pushed to the Hub via a bounded single-worker executor
    so no push is ever silently dropped and exactly one push runs at a
    time.

    Policy state (hidden state, RTC queue) intentionally persists across
    episode boundaries — Sentry slices one continuous rollout, the robot
    does not reset between slices.

    Requires ``streaming_encoding=True`` (enforced in config validation)
    to prevent disk I/O from blocking the control loop.
    """

    config: SentryStrategyConfig

    def __init__(self, config: SentryStrategyConfig):
        super().__init__(config)
        self._push_executor: ThreadPoolExecutor | None = None
        self._pending_push: Future | None = None
        self._needs_push = Event()
        self._episode_lock = Lock()

    def setup(self, ctx: RolloutContext) -> None:
        """Initialise the inference engine and background push executor."""
        self._init_engine(ctx)
        self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push")
        target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
        self._episode_duration_s = estimate_max_episode_seconds(
            ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
        )
        logger.info(
            "Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
            self._episode_duration_s,
            self.config.upload_every_n_episodes,
        )

    def run(self, ctx: RolloutContext) -> None:
        """Run the continuous recording loop with automatic episode rotation."""
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        dataset = ctx.data.dataset
        interpolator = self._interpolator
        features = ctx.data.dataset_features

        control_interval = interpolator.get_control_interval(cfg.fps)

        engine.resume()
        play_sounds = cfg.play_sounds
        episode_duration_s = self._episode_duration_s

        start_time = time.perf_counter()
        episode_start = time.perf_counter()
        episodes_since_push = 0
        task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
        logger.info("Sentry recording started (episode_duration=%.0fs)", episode_duration_s)

        with VideoEncodingManager(dataset):
            try:
                while not ctx.runtime.shutdown_event.is_set():
                    loop_start = time.perf_counter()

                    if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
                        logger.info("Duration limit reached (%.0fs)", cfg.duration)
                        break

                    obs = robot.get_observation()
                    obs_processed = ctx.processors.robot_observation_processor(obs)
                    engine.notify_observation(obs_processed)

                    if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                        continue

                    action_dict = send_next_action(obs_processed, obs, ctx, interpolator)

                    if action_dict is not None:
                        self._log_telemetry(obs_processed, action_dict, ctx.runtime)
                        obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
                        action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
                        frame = {**obs_frame, **action_frame, "task": task_str}
                        # ``add_frame`` writes to the in-progress episode buffer; the
                        # background pusher only ever touches *finalised* episode
                        # artifacts on disk. The two operate on disjoint state, so
                        # ``add_frame`` does not need ``_episode_lock``.
                        dataset.add_frame(frame)

                    # Episode rotation derived from video file-size target.
                    # The duration is a conservative estimate so the actual
                    # video has crossed DEFAULT_VIDEO_FILE_SIZE_IN_MB by now,
                    # keeping push_to_hub efficient (uploads complete files).
                    elapsed = time.perf_counter() - episode_start
                    if elapsed >= episode_duration_s:
                        # ``save_episode`` finalises the in-progress episode and
                        # flushes it to disk; ``_episode_lock`` serialises this with
                        # ``push_to_hub`` (run in the background executor) so the
                        # pusher never reads a half-written episode.
                        with self._episode_lock:
                            dataset.save_episode()
                        episodes_since_push += 1
                        self._needs_push.set()
                        logger.info(
                            "Episode saved (total: %d, elapsed: %.1fs)",
                            dataset.num_episodes,
                            elapsed,
                        )
                        log_say(f"Episode {dataset.num_episodes} saved", play_sounds)

                        if episodes_since_push >= self.config.upload_every_n_episodes:
                            self._background_push(dataset, cfg)
                            episodes_since_push = 0

                        episode_start = time.perf_counter()

                    dt = time.perf_counter() - loop_start
                    if (sleep_t := control_interval - dt) > 0:
                        precise_sleep(sleep_t)

            finally:
                logger.info("Sentry control loop ended — saving final episode")
                with contextlib.suppress(Exception):
                    with self._episode_lock:
                        dataset.save_episode()
                    self._needs_push.set()

    def teardown(self, ctx: RolloutContext) -> None:
        """Flush pending pushes, finalise the dataset, and disconnect hardware."""
        play_sounds = ctx.runtime.cfg.play_sounds
        logger.info("Stopping sentry recording")
        log_say("Stopping sentry recording", play_sounds)

        # Flush any queued/running push cleanly.
        if self._push_executor is not None:
            logger.info("Shutting down push executor (waiting for pending pushes)...")
            self._push_executor.shutdown(wait=True)
            self._push_executor = None

        if ctx.data.dataset is not None:
            logger.info("Finalizing dataset...")
            ctx.data.dataset.finalize()
            if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
                logger.info("Pushing final dataset to hub...")
                if safe_push_to_hub(
                    ctx.data.dataset,
                    tags=ctx.runtime.cfg.dataset.tags,
                    private=ctx.runtime.cfg.dataset.private,
                ):
                    logger.info("Dataset uploaded to hub")
                    log_say("Dataset uploaded to hub", play_sounds)

        self._teardown_hardware(ctx.hardware)
        logger.info("Sentry strategy teardown complete")

    def _background_push(self, dataset, cfg) -> None:
        """Queue a Hub push on the single-worker executor.

        The executor's max_workers=1 guarantees at most one push runs at
        a time; submitted tasks are queued rather than dropped.
        """
        if self._push_executor is None:
            return

        if self._pending_push is not None and not self._pending_push.done():
            logger.info("Previous push still in progress; queueing next")

        def _push():
            try:
                with self._episode_lock:
                    if safe_push_to_hub(
                        dataset,
                        tags=cfg.dataset.tags if cfg.dataset else None,
                        private=cfg.dataset.private if cfg.dataset else False,
                    ):
                        self._needs_push.clear()
                        logger.info("Background push to hub complete")
            except Exception as e:
                logger.error("Background push failed: %s", e)

        self._pending_push = self._push_executor.submit(_push)
        logger.info("Background push task submitted")
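`_background_push` relies on two properties of `ThreadPoolExecutor(max_workers=1)`: submissions queue rather than drop, and at most one push runs at a time. A minimal sketch of that pattern, with a hypothetical `fake_push` standing in for `safe_push_to_hub` and the same lock discipline as `SentryStrategy`:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

episode_lock = threading.Lock()
pushed = []


def fake_push(tag):
    # Serialised with episode finalisation via the shared lock, mirroring
    # how _push holds _episode_lock around push_to_hub.
    with episode_lock:
        pushed.append(tag)


executor = ThreadPoolExecutor(max_workers=1)  # one push at a time; extras queue
futures = [executor.submit(fake_push, i) for i in range(3)]
executor.shutdown(wait=True)  # teardown waits for pending pushes, as in teardown()

assert pushed == [0, 1, 2]  # single worker => FIFO order, nothing dropped
```

The `shutdown(wait=True)` call is what makes teardown safe: any push queued during the run completes before the dataset is finalised and the process exits.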
@@ -286,7 +286,7 @@ def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
    if len(set(num_eps_per_cam)) != 1:
        raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")

    episods_metadata = []
    episodes_metadata = []
    num_cameras = len(video_keys)
    num_episodes = num_eps_per_cam[0]
    for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"):
@@ -299,9 +299,9 @@ def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
        ep_dict = {}
        for cam_idx in range(num_cameras):
            ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
        episods_metadata.append(ep_dict)
        episodes_metadata.append(ep_dict)

    return episods_metadata
    return episodes_metadata


def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int):

@@ -13,70 +13,62 @@
# limitations under the License.

"""
Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy.
Records a dataset via teleoperation. This is a pure data-collection
tool — no policy inference. For deploying trained policies, use
``lerobot-rollout`` instead.

Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras)

Example:

```shell
lerobot-record \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58760431541 \
    --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
    --robot.id=black \
    --dataset.repo_id=<my_username>/<my_dataset_name> \
    --dataset.num_episodes=2 \
    --dataset.single_task="Grab the cube" \
    --dataset.streaming_encoding=true \
    --dataset.encoder_threads=2 \
lerobot-record \\
    --robot.type=so100_follower \\
    --robot.port=/dev/tty.usbmodem58760431541 \\
    --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
    --robot.id=black \\
    --teleop.type=so100_leader \\
    --teleop.port=/dev/tty.usbmodem58760431551 \\
    --teleop.id=blue \\
    --dataset.repo_id=<my_username>/<my_dataset_name> \\
    --dataset.num_episodes=2 \\
    --dataset.single_task="Grab the cube" \\
    --dataset.streaming_encoding=true \\
    --dataset.encoder_threads=2 \\
    --display_data=true
    # <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
    # --dataset.vcodec=h264 \
    # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
    # --teleop.type=so100_leader \
    # --teleop.port=/dev/tty.usbmodem58760431551 \
    # --teleop.id=blue \
    # <- Policy optional if you want to record with a policy \
    # --policy.path=${HF_USER}/my_policy \
```

Example recording with bimanual so100:
```shell
lerobot-record \
    --robot.type=bi_so_follower \
    --robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \
    --robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \
    --robot.id=bimanual_follower \
lerobot-record \\
    --robot.type=bi_so_follower \\
    --robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \\
    --robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \\
    --robot.id=bimanual_follower \\
    --robot.left_arm_config.cameras='{
        wrist: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30},
        top: {"type": "opencv", "index_or_path": 3, "width": 640, "height": 480, "fps": 30},
    }' --robot.right_arm_config.cameras='{
        wrist: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
        front: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
    }' \
    --teleop.type=bi_so_leader \
    --teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \
    --teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \
    --teleop.id=bimanual_leader \
    --display_data=true \
    --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
    --dataset.num_episodes=25 \
    --dataset.single_task="Grab and handover the red cube to the other arm" \
    --dataset.streaming_encoding=true \
    # --dataset.vcodec=auto \
    }' \\
    --teleop.type=bi_so_leader \\
    --teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \\
    --teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \\
    --teleop.id=bimanual_leader \\
    --display_data=true \\
    --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \\
    --dataset.num_episodes=25 \\
    --dataset.single_task="Grab and handover the red cube to the other arm" \\
    --dataset.streaming_encoding=true \\
    --dataset.encoder_threads=2
```
"""

import logging
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from dataclasses import asdict, dataclass
from pprint import pformat
from typing import Any

import torch

from lerobot.cameras import CameraConfig  # noqa: F401
from lerobot.cameras.opencv import OpenCVCameraConfig  # noqa: F401
@@ -86,11 +78,10 @@ from lerobot.cameras.zmq import ZMQCameraConfig  # noqa: F401
from lerobot.common.control_utils import (
    init_keyboard_listener,
    is_headless,
    predict_action,
    sanity_check_dataset_name,
    sanity_check_dataset_robot_compatibility,
)
from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
    LeRobotDataset,
    VideoEncodingManager,
@@ -98,21 +89,11 @@ from lerobot.datasets import (
    create_initial_features,
    safe_stop_image_writer,
)
from lerobot.policies import (
    ActionInterpolator,
    PreTrainedPolicy,
    make_policy,
    make_pre_post_processors,
    make_robot_action,
)
from lerobot.processor import (
    PolicyAction,
    PolicyProcessorPipeline,
    RobotAction,
    RobotObservation,
    RobotProcessorPipeline,
    make_default_processors,
    rename_stats,
)
from lerobot.robots import (  # noqa: F401
    Robot,
@@ -146,7 +127,6 @@ from lerobot.teleoperators import (  # noqa: F401
)
from lerobot.teleoperators.keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
@@ -157,71 +137,12 @@ from lerobot.utils.utils import (
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data


@dataclass
class DatasetRecordConfig:
    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
    single_task: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
    root: str | Path | None = None
    # Limit the frames per second.
    fps: int = 30
    # Number of seconds for data recording for each episode.
    episode_time_s: int | float = 60
    # Number of seconds for resetting the environment after each episode.
    reset_time_s: int | float = 60
    # Number of episodes to record.
    num_episodes: int = 50
    # Encode frames in the dataset into video
    video: bool = True
    # Upload dataset to Hugging Face hub.
    push_to_hub: bool = True
    # Upload on private repository on the Hugging Face hub.
    private: bool = False
    # Add tags to your dataset on the hub.
    tags: list[str] | None = None
    # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
    # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
    # and threads depends on your system. We recommend 4 threads per camera with 0 processes.
    # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
    num_image_writer_processes: int = 0
    # Number of threads writing the frames as png images on disk, per camera.
    # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
    # Not enough threads might cause low camera fps.
    num_image_writer_threads_per_camera: int = 4
    # Number of episodes to record before batch encoding videos.
    # Set to 1 for immediate encoding (default behavior), or higher for batched encoding.
    video_encoding_batch_size: int = 1
    # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
    # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
    # Use 'auto' to auto-detect the best available hardware encoder.
    vcodec: str = "libsvtav1"
    # Enable streaming video encoding: encode frames in real-time during capture instead
    # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
    streaming_encoding: bool = False
    # Maximum number of frames to buffer per camera when using streaming encoding.
    # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
    encoder_queue_maxsize: int = 30
    # Number of threads per encoder instance. None = auto (codec default).
    # Lower values reduce CPU usage; maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
    encoder_threads: int | None = None
    # Rename map for the observation to override the image and state keys
    rename_map: dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        if self.single_task is None:
            raise ValueError("You need to provide a task as argument in `single_task`.")


@dataclass
class RecordConfig:
    robot: RobotConfig
    dataset: DatasetRecordConfig
    # Whether to control the robot with a teleoperator
    # Teleoperator to control the robot (required)
    teleop: TeleoperatorConfig | None = None
    # Whether to control the robot with a policy
    policy: PreTrainedConfig | None = None
    # Display all cameras on screen
    display_data: bool = False
    # Display data on a remote Rerun server
@@ -234,27 +155,14 @@ class RecordConfig:
    play_sounds: bool = True
    # Resume recording on an existing dataset.
    resume: bool = False
    # Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
    # Only applies when using a policy (not teleop)
    interpolation_multiplier: int = 1

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")

        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")

            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

        if self.teleop is None and self.policy is None:
            raise ValueError("Choose a policy, a teleoperator or both to control the robot")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]
        if self.teleop is None:
            raise ValueError(
                "A teleoperator is required for recording. "
                "Use --teleop.type=... to specify one. "
                "For policy-based deployment, use lerobot-rollout instead."
            )


""" --------------- record_loop() data flow --------------------------
|
||||
@@ -264,18 +172,14 @@ class RecordConfig:
|
||||
V
|
||||
[ robot_observation_processor ] ---> processed_obs
|
||||
V
|
||||
.-----( ACTION LOGIC )------------------.
|
||||
V V
|
||||
[ From Teleoperator ] [ From Policy ]
|
||||
| |
|
||||
| [teleop.get_action] -> raw_action | [predict_action]
|
||||
| | | |
|
||||
| V | V
|
||||
| [teleop_action_processor] | |
|
||||
| | | |
|
||||
'---> processed_teleop_action '---> processed_policy_action
|
||||
| |
|
||||
'-------------------------.-------------'
|
||||
[ Teleoperator ]
|
||||
|
|
||||
| [teleop.get_action] -> raw_action
|
||||
| |
|
||||
| V
|
||||
| [teleop_action_processor]
|
||||
| |
|
||||
'---> processed_teleop_action
|
||||
V
|
||||
[ robot_action_processor ] --> robot_action_to_send
|
||||
V
|
||||
@@ -303,13 +207,9 @@ def record_loop(
    ],  # runs after robot
    dataset: LeRobotDataset | None = None,
    teleop: Teleoperator | list[Teleoperator] | None = None,
    policy: PreTrainedPolicy | None = None,
    preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
    postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
    control_time_s: int | None = None,
    single_task: str | None = None,
    display_data: bool = False,
    interpolator: ActionInterpolator | None = None,
    display_compressed_images: bool = False,
):
    if dataset is not None and dataset.fps != fps:
@@ -340,21 +240,7 @@ def record_loop(
            "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
        )

    # Reset policy and processor if they are provided
    if policy is not None and preprocessor is not None and postprocessor is not None:
        policy.reset()
        preprocessor.reset()
        postprocessor.reset()

    # Reset interpolator if provided
    if interpolator is not None:
        interpolator.reset()

    # Calculate control interval based on interpolation
    use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
    control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
    # Pre-compute action key order outside the hot loop — it won't change mid-episode.
    action_keys = sorted(robot.action_features) if use_interpolation else []
    control_interval = 1 / fps

    no_action_count = 0
    timestamp = 0
@@ -372,63 +258,11 @@ def record_loop(
        # Applies a pipeline to the raw robot observation, default is IdentityProcessor
        obs_processed = robot_observation_processor(obs)

        if policy is not None or dataset is not None:
            if dataset is not None:
                observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)

        # Track whether this iteration should be recorded to the dataset.
        # Interpolated-only iterations send actions to the robot but don't record frames,
        # keeping the dataset at the original fps while the robot moves at the higher rate.
        is_record_frame = True

        # Get action from either policy or teleop
        if policy is not None and preprocessor is not None and postprocessor is not None:
            # With interpolation: only call policy when interpolator needs new action
            if use_interpolation:
                ran_inference = False

                if interpolator.needs_new_action():
                    action_values = predict_action(
                        observation=observation_frame,
                        policy=policy,
                        device=get_safe_torch_device(policy.config.device),
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        use_amp=policy.config.use_amp,
                        task=single_task,
                        robot_type=robot.robot_type,
                    )
                    act_processed_policy = make_robot_action(action_values, dataset.features)
                    robot_action_to_send = robot_action_processor((act_processed_policy, obs))

                    action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
                    interpolator.add(action_tensor)
                    ran_inference = True

                interp_action = interpolator.get()
                if interp_action is not None:
                    robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
                    action_values = robot_action_to_send
                else:
                    continue

                is_record_frame = ran_inference
            else:
                action_values = predict_action(
                    observation=observation_frame,
                    policy=policy,
                    device=get_safe_torch_device(policy.config.device),
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                    use_amp=policy.config.use_amp,
                    task=single_task,
                    robot_type=robot.robot_type,
                )
                act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
                # Applies a pipeline to the action, default is IdentityProcessor
                robot_action_to_send = robot_action_processor((act_processed_policy, obs))
                action_values = robot_action_to_send

        elif policy is None and isinstance(teleop, Teleoperator):
        # Get action from teleop
        if isinstance(teleop, Teleoperator):
            act = teleop.get_action()
            if robot.name == "unitree_g1":
                teleop.send_feedback(obs)
@@ -438,7 +272,7 @@ def record_loop(
            action_values = act_processed_teleop
            robot_action_to_send = robot_action_processor((act_processed_teleop, obs))

        elif policy is None and isinstance(teleop, list):
        elif isinstance(teleop, list):
            arm_action = teleop_arm.get_action()
            arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
            keyboard_action = teleop_keyboard.get_action()
@@ -451,7 +285,7 @@ def record_loop(
            no_action_count += 1
            if no_action_count == 1 or no_action_count % 10 == 0:
                logging.warning(
                    "No policy or teleoperator provided, skipping action generation. "
                    "No teleoperator provided, skipping action generation. "
                    "This is likely to happen when resetting the environment without a teleop device. "
                    "The robot won't be at its rest position at the start of the next episode."
                )
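The warning above is rate-limited: it fires on the first missing action and then only every tenth one, so a long idle stretch (e.g. an environment reset with no teleop device) does not flood the log. The predicate in isolation, as a standalone sketch (`should_warn` is a hypothetical helper name, not part of lerobot):

```python
def should_warn(no_action_count: int) -> bool:
    # Warn on the very first miss, then on every 10th miss after that.
    return no_action_count == 1 or no_action_count % 10 == 0

print([n for n in range(1, 25) if should_warn(n)])  # -> [1, 10, 20]
```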
@@ -463,8 +297,8 @@ def record_loop(
        # TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
        _sent_action = robot.send_action(robot_action_to_send)

        # Write to dataset (only on real policy frames, not interpolated-only iterations)
        if dataset is not None and is_record_frame:
        # Write to dataset
        if dataset is not None:
            action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
            frame = {**observation_frame, **action_frame, "task": single_task}
            dataset.add_frame(frame)
@@ -488,7 +322,12 @@ def record_loop(

@parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset:
def record(
    cfg: RecordConfig,
    teleop_action_processor: RobotProcessorPipeline | None = None,
    robot_action_processor: RobotProcessorPipeline | None = None,
    robot_observation_processor: RobotProcessorPipeline | None = None,
) -> LeRobotDataset:
    init_logging()
    logging.info(pformat(asdict(cfg)))
    if cfg.display_data:
@@ -502,7 +341,16 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
    robot = make_robot_from_config(cfg.robot)
    teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None

    teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
    # Fall back to identity pipelines when the caller doesn't supply processors.
    if (
        teleop_action_processor is None
        or robot_action_processor is None
        or robot_observation_processor is None
    ):
        _t, _r, _o = make_default_processors()
        teleop_action_processor = teleop_action_processor or _t
        robot_action_processor = robot_action_processor or _r
        robot_observation_processor = robot_observation_processor or _o

    dataset_features = combine_feature_dicts(
        aggregate_pipeline_dataset_features(
@@ -540,8 +388,12 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
        )
        sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
    else:
        # Create empty dataset or load existing saved episodes
        sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
        # Reject eval_ prefix — for policy evaluation use lerobot-rollout
        if cfg.dataset.repo_id.startswith("eval_"):
            raise ValueError(
                "Dataset names starting with 'eval_' are reserved for policy evaluation. "
                "lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
            )
        dataset = LeRobotDataset.create(
            cfg.dataset.repo_id,
            cfg.dataset.fps,
@@ -558,26 +410,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
            encoder_threads=cfg.dataset.encoder_threads,
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    preprocessor = None
    postprocessor = None
    interpolator = None
    if cfg.policy is not None:
        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
                "rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
            },
        )
        # Create interpolator for smoother policy control
        if cfg.interpolation_multiplier > 1:
            interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
            logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")

    robot.connect()
    if teleop is not None:
        teleop.connect()
@@ -601,14 +433,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
            robot_action_processor=robot_action_processor,
            robot_observation_processor=robot_observation_processor,
            teleop=teleop,
            policy=policy,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            dataset=dataset,
            control_time_s=cfg.dataset.episode_time_s,
            single_task=cfg.dataset.single_task,
            display_data=cfg.display_data,
            interpolator=interpolator,
            display_compressed_images=display_compressed_images,
        )

@@ -656,7 +484,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
    listener.stop()

    if cfg.dataset.push_to_hub:
        dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
    if dataset and dataset.num_episodes > 0:
        dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
    else:
        logging.warning("No episodes saved — skipping push to hub")

    log_say("Exiting", cfg.play_sounds)
    return dataset

@@ -0,0 +1,211 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Policy deployment engine with pluggable rollout strategies.
|
||||
|
||||
``lerobot-rollout`` is the single CLI for running trained policies on
|
||||
real robots.
|
||||
|
||||
Strategies
|
||||
----------
|
||||
--strategy.type=base Autonomous rollout, no recording
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
--inference.type=sync One policy call per control tick (default)
|
||||
--inference.type=rtc Real-Time Chunking for slow VLA models
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
::
|
||||
|
||||
# Base mode — quick evaluation with sync inference
|
||||
lerobot-rollout \\
|
||||
--strategy.type=base \\
|
||||
--policy.path=lerobot/act_koch_real \\
|
||||
--robot.type=koch_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--task="pick up cube" --duration=30
|
||||
|
||||
# Base mode — RTC inference for slow VLAs (Pi0, Pi0.5, SmolVLA)
|
||||
lerobot-rollout \\
|
||||
--strategy.type=base \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=rtc \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--inference.rtc.max_guidance_weight=10.0 \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
|
||||
--task="pick up cube" --duration=60
|
||||
|
||||
# Sentry mode — continuous recording with periodic upload
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--strategy.upload_every_n_episodes=5 \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=rtc \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/sentry-data \\
|
||||
--dataset.single_task="patrol" --duration=3600
|
||||
|
||||
# Highlight mode — ring buffer, press 's' to save, 'h' to push
|
||||
lerobot-rollout \\
|
||||
--strategy.type=highlight \\
|
||||
--strategy.ring_buffer_seconds=30 \\
|
||||
--policy.path=lerobot/act_koch_real \\
|
||||
--robot.type=koch_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/highlight-data \\
|
||||
--dataset.single_task="pick up cube"
|
||||
|
||||
# DAgger mode — human-in-the-loop corrections only
|
||||
lerobot-rollout \\
|
||||
--strategy.type=dagger \\
|
||||
--strategy.num_episodes=20 \\
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \\
|
||||
--robot.type=bi_openarm_follower \\
|
||||
--teleop.type=openarm_mini \\
|
||||
--dataset.repo_id=user/hil-data \\
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
|
||||
# DAgger mode — continuous recording with RTC inference
|
||||
lerobot-rollout \\
|
||||
--strategy.type=dagger \\
|
||||
--strategy.record_autonomous=true \\
|
||||
--strategy.num_episodes=50 \\
|
||||
--inference.type=rtc \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--policy.path=user/my_pi0_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so101_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/dagger-rtc-data \\
|
||||
--dataset.single_task="Grasp the block"
|
||||
|
||||
# With Rerun visualization and torch.compile
|
||||
lerobot-rollout \\
|
||||
--strategy.type=base \\
|
||||
--policy.path=lerobot/act_koch_real \\
|
||||
--robot.type=koch_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--task="pick up cube" --duration=60 \\
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/sentry-data \\
|
||||
--dataset.single_task="patrol" \\
|
||||
--resume=true
|
||||
"""
|
||||
|
||||
import logging

from lerobot.cameras.opencv import OpenCVCameraConfig  # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig  # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig  # noqa: F401
from lerobot.configs import parser
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
    bi_openarm_follower,
    bi_so_follower,
    earthrover_mini_plus,
    hope_jr,
    koch_follower,
    omx_follower,
    openarm_follower,
    reachy2,
    so_follower,
    unitree_g1 as unitree_g1_robot,
)
from lerobot.rollout import RolloutConfig, build_rollout_context, create_strategy
from lerobot.teleoperators import (  # noqa: F401
    Teleoperator,
    TeleoperatorConfig,
    bi_openarm_leader,
    bi_so_leader,
    homunculus,
    koch_leader,
    omx_leader,
    openarm_leader,
    openarm_mini,
    reachy2_teleoperator,
    so_leader,
    unitree_g1,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
from lerobot.utils.visualization_utils import init_rerun

logger = logging.getLogger(__name__)

@parser.wrap()
def rollout(cfg: RolloutConfig):
    """Main entry point for policy deployment."""
    init_logging()

    if cfg.display_data:
        logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port)
        init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)

    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event

    logger.info("Building rollout context...")
    ctx = build_rollout_context(cfg, shutdown_event)

    strategy = create_strategy(cfg.strategy)
    logger.info("Rollout strategy: %s", cfg.strategy.type)
    logger.info(
        "Robot: %s | FPS: %.0f | Duration: %s",
        cfg.robot.type if cfg.robot else "?",
        cfg.fps,
        f"{cfg.duration}s" if cfg.duration > 0 else "infinite",
    )

    try:
        strategy.setup(ctx)
        logger.info("Rollout setup complete, starting rollout...")
        strategy.run(ctx)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        strategy.teardown(ctx)

    logger.info("Rollout finished")


def main():
    """CLI entry point for ``lerobot-rollout``."""
    register_third_party_plugins()
    rollout()


if __name__ == "__main__":
    main()
@@ -386,7 +386,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
        sampler=sampler,
        pin_memory=device.type == "cuda",
        drop_last=False,
        prefetch_factor=2 if cfg.num_workers > 0 else None,
        prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
        persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
    )

    # Prepare everything with accelerator
@@ -433,6 +434,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        for cam_key in dataset.meta.camera_keys:
            if cam_key in batch and batch[cam_key].dtype == torch.uint8:
                batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
        batch = preprocessor(batch)
        train_tracker.dataloading_s = time.perf_counter() - start_time

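The camera-batch conversion added above rescales uint8 images to float32 in [0, 1] before the preprocessor runs. The underlying arithmetic, sketched here without torch (`normalize_pixels` is a hypothetical helper for illustration only):

```python
def normalize_pixels(pixels: list[int]) -> list[float]:
    # uint8 values span 0..255; dividing by 255 maps them onto [0.0, 1.0].
    return [p / 255.0 for p in pixels]

print(normalize_pixels([0, 51, 255]))  # -> [0.0, 0.2, 1.0]
```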
@@ -15,9 +15,22 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from lerobot.utils.import_utils import _hidapi_available, _pygame_available, require_package

from ..utils import TeleopEvents

if TYPE_CHECKING or _pygame_available:
    import pygame
else:
    pygame = None  # type: ignore[assignment]

if TYPE_CHECKING or _hidapi_available:
    import hid
else:
    hid = None  # type: ignore[assignment]


class InputController:
    """Base class for input controllers that generate motion deltas."""
@@ -199,6 +212,7 @@ class GamepadController(InputController):
    """Generate motion deltas from gamepad input."""

    def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
        require_package("pygame", extra="gamepad")
        super().__init__(x_step_size, y_step_size, z_step_size)
        self.deadzone = deadzone
        self.joystick = None
@@ -206,8 +220,6 @@ class GamepadController(InputController):

    def start(self):
        """Initialize pygame and the gamepad."""
        import pygame

        pygame.init()
        pygame.joystick.init()

@@ -230,8 +242,6 @@ class GamepadController(InputController):

    def stop(self):
        """Clean up pygame resources."""
        import pygame

        if pygame.joystick.get_init():
            if self.joystick:
                self.joystick.quit()
@@ -240,8 +250,6 @@ class GamepadController(InputController):

    def update(self):
        """Process pygame events to get fresh gamepad readings."""
        import pygame

        for event in pygame.event.get():
            if event.type == pygame.JOYBUTTONDOWN:
                if event.button == 3:
@@ -280,8 +288,6 @@ class GamepadController(InputController):

    def get_deltas(self):
        """Get the current movement deltas from gamepad state."""
        import pygame

        try:
            # Read joystick axes
            # Left stick X and Y (typically axes 0 and 1)
@@ -326,6 +332,7 @@ class GamepadControllerHID(InputController):
            z_scale: Scaling factor for Z-axis movement
            deadzone: Joystick deadzone to prevent drift
        """
        require_package("hidapi", extra="gamepad", import_name="hid")
        super().__init__(x_step_size, y_step_size, z_step_size)
        self.deadzone = deadzone
        self.device = None
@@ -342,8 +349,6 @@ class GamepadControllerHID(InputController):

    def find_device(self):
        """Look for the gamepad device by vendor and product ID."""
        import hid

        devices = hid.enumerate()
        for device in devices:
            device_name = device["product_string"]
@@ -357,8 +362,6 @@ class GamepadControllerHID(InputController):

    def start(self):
        """Connect to the gamepad using HIDAPI."""
        import hid

        self.device_info = self.find_device()
        if not self.device_info:
            self.running = False

@@ -45,7 +45,7 @@ class HomunculusArm(Teleoperator):
    name = "homunculus_arm"

    def __init__(self, config: HomunculusArmConfig):
        require_package("pyserial", extra="hardware", import_name="serial")
        require_package("pyserial", extra="pyserial-dep", import_name="serial")
        super().__init__(config)
        self.config = config
        self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)

@@ -71,7 +71,7 @@ class HomunculusGlove(Teleoperator):
    name = "homunculus_glove"

    def __init__(self, config: HomunculusGloveConfig):
        require_package("pyserial", extra="hardware", import_name="serial")
        require_package("pyserial", extra="pyserial-dep", import_name="serial")
        super().__init__(config)
        self.config = config
        self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)

@@ -23,7 +23,7 @@ from typing import Any

from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.import_utils import _pynput_available, require_package

from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -56,6 +56,7 @@ class KeyboardTeleop(Teleoperator):
    name = "keyboard"

    def __init__(self, config: KeyboardTeleopConfig):
        require_package("pynput", extra="pynput-dep")
        super().__init__(config)
        self.config = config
        self.robot_type = config.type

@@ -21,14 +21,24 @@
import logging
import threading
import time
from typing import TYPE_CHECKING

import hebi
import numpy as np
from teleop import Teleop

from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _hebi_available, _teleop_available, require_package
from lerobot.utils.rotation import Rotation

if TYPE_CHECKING or _hebi_available:
    import hebi
else:
    hebi = None

if TYPE_CHECKING or _teleop_available:
    from teleop import Teleop
else:
    Teleop = None

from ..teleoperator import Teleoperator
from .config_phone import PhoneConfig, PhoneOS

@@ -74,6 +84,8 @@ class IOSPhone(BasePhone, Teleoperator):
    name = "ios_phone"

    def __init__(self, config: PhoneConfig):
        require_package("hebi-py", extra="phone", import_name="hebi")
        require_package("teleop", extra="phone")
        super().__init__(config)
        self.config = config
        self._group = None
@@ -213,6 +225,8 @@ class AndroidPhone(BasePhone, Teleoperator):
    name = "android_phone"

    def __init__(self, config: PhoneConfig):
        require_package("hebi-py", extra="phone", import_name="hebi")
        require_package("teleop", extra="phone")
        super().__init__(config)
        self.config = config
        self._teleop = None

@@ -19,7 +19,7 @@ import logging
import time
from typing import TYPE_CHECKING

from lerobot.utils.import_utils import _reachy2_sdk_available
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package

if TYPE_CHECKING or _reachy2_sdk_available:
    from reachy2_sdk import ReachySDK
@@ -84,6 +84,7 @@ class Reachy2Teleoperator(Teleoperator):
    name = "reachy2_specific"

    def __init__(self, config: Reachy2TeleoperatorConfig):
        require_package("reachy2_sdk", extra="reachy2")
        super().__init__(config)

        self.config = config

@@ -34,7 +34,7 @@ from typing import TYPE_CHECKING

import numpy as np

from lerobot.utils.import_utils import _serial_available
from lerobot.utils.import_utils import _serial_available, require_package

if TYPE_CHECKING or _serial_available:
    import serial
@@ -156,6 +156,7 @@ def run_exo_calibration(
    """
    Run interactive calibration for an exoskeleton arm.
    """
    require_package("pyserial", extra="unitree_g1", import_name="serial")
    try:
        import cv2
        import matplotlib.pyplot as plt

@@ -76,7 +76,7 @@ class ExoskeletonArm:
    calibration: ExoskeletonCalibration | None = None

    def __post_init__(self):
        require_package("pyserial", extra="hardware", import_name="serial")
        require_package("pyserial", extra="unitree_g1", import_name="serial")
        if self.calibration_fpath.is_file():
            self._load_calibration()

@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Action interpolation for smoother robot control.

Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""

from torch import Tensor

class ActionInterpolator:
    """Interpolates between consecutive actions for smoother control.

    When enabled with multiplier N, produces N actions per policy action
    by linearly interpolating between the previous and current action.

    Example with multiplier=3:
        prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]

    This effectively multiplies the control rate for smoother motion.

    Usage:
        interpolator = ActionInterpolator(multiplier=2)  # 2x control rate

        # In control loop:
        if interpolator.needs_new_action():
            new_action = queue.get()
            if new_action:
                interpolator.add(new_action.cpu())

        action = interpolator.get()
        if action:
            robot.send_action(action)
    """

    def __init__(self, multiplier: int = 1):
        """Initialize the interpolator.

        Args:
            multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
        """
        if multiplier < 1:
            raise ValueError(f"multiplier must be >= 1, got {multiplier}")
        self.multiplier = multiplier
        self._prev: Tensor | None = None
        self._buffer: list[Tensor] = []
        self._idx = 0

    @property
    def enabled(self) -> bool:
        """Whether interpolation is active (multiplier > 1)."""
        return self.multiplier > 1

    def reset(self):
        """Reset interpolation state (call between episodes)."""
        self._prev = None
        self._buffer = []
        self._idx = 0

    def needs_new_action(self) -> bool:
        """Check if a new action is needed from the queue."""
        return self._idx >= len(self._buffer)

    def add(self, action: Tensor) -> None:
        """Add a new action and compute interpolated sequence.

        Args:
            action: New action tensor from policy/queue (already on CPU).
        """
        if self.multiplier > 1 and self._prev is not None:
            self._buffer = []
            for i in range(1, self.multiplier + 1):
                t = i / self.multiplier
                interp = self._prev + t * (action - self._prev)
                self._buffer.append(interp)
        else:
            # First step: no previous action yet, so run at base FPS without interpolation.
            self._buffer = [action.clone()]
        self._prev = action.clone()
        self._idx = 0

    def get(self) -> Tensor | None:
        """Get the next interpolated action.

        Returns:
            Next action tensor, or None if buffer is exhausted.
        """
        if self._idx >= len(self._buffer):
            return None
        action = self._buffer[self._idx]
        self._idx += 1
        return action

    def get_control_interval(self, fps: float) -> float:
        """Get the control interval based on interpolation multiplier.

        Args:
            fps: Base frames per second.

        Returns:
            Control interval in seconds (divided by multiplier).
        """
        return 1.0 / (fps * self.multiplier)
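The ramp computed in `add()` above steps from the previous action toward the new one in `multiplier` equal increments, with the last step landing exactly on the new action. A minimal standalone sketch of the same rule using plain Python lists instead of torch tensors (`interpolate` is a hypothetical helper for illustration, not part of lerobot):

```python
def interpolate(prev: list[float], new: list[float], multiplier: int) -> list[list[float]]:
    # Same rule as ActionInterpolator.add(): t runs over 1/N, 2/N, ..., N/N,
    # so the final interpolated action equals the new action.
    steps = []
    for i in range(1, multiplier + 1):
        t = i / multiplier
        steps.append([p + t * (n - p) for p, n in zip(prev, new)])
    return steps

print(interpolate([0.0, 10.0], [2.0, 14.0], 2))  # -> [[1.0, 12.0], [2.0, 14.0]]
```

Consuming these extra actions requires a faster control tick, which is what `get_control_interval` provides: at 30 fps with multiplier 3, the loop runs every 1/90 s while the policy is still queried at 30 Hz.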
@@ -115,6 +115,12 @@ _feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_pyrealsense2_available = is_package_available("pyrealsense2")
_zmq_available = is_package_available("pyzmq", import_name="zmq")
_hebi_available = is_package_available("hebi-py", import_name="hebi")
_teleop_available = is_package_available("teleop")
_placo_available = is_package_available("placo")
_hidapi_available = is_package_available("hidapi", import_name="hid")

# Data / serialization
_pandas_available = is_package_available("pandas")

@@ -0,0 +1,83 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic foot pedal listener using evdev.

Callers supply a callback receiving the pressed key code (e.g. ``"KEY_A"``)
and an optional device path. The listener runs in a daemon thread and
silently no-ops when :mod:`evdev` is not installed or the device is
unavailable. Strategy-specific key mapping logic lives in the caller.
"""

from __future__ import annotations

import logging
import threading
from collections.abc import Callable

logger = logging.getLogger(__name__)

DEFAULT_PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"


def start_pedal_listener(
    on_press: Callable[[str], None],
    device_path: str = DEFAULT_PEDAL_DEVICE,
) -> threading.Thread | None:
    """Spawn a daemon thread that forwards pedal key-press codes to ``on_press``.

    Parameters
    ----------
    on_press:
        Callback invoked with the pressed key code string (e.g. ``"KEY_A"``)
        on each pedal press event. The callback runs in the listener thread
        and must be thread-safe.
    device_path:
        Linux input device path (e.g. ``/dev/input/by-id/...``).

    Returns
    -------
    The started daemon :class:`threading.Thread`, or ``None`` when
    :mod:`evdev` is not installed (optional dependency; silent no-op).
    """
    try:
        from evdev import InputDevice, categorize, ecodes
    except ImportError:
        return None

    def pedal_reader() -> None:
        try:
            dev = InputDevice(device_path)
            logger.info("Pedal connected: %s", dev.name)
            for ev in dev.read_loop():
                if ev.type != ecodes.EV_KEY:
                    continue
                key = categorize(ev)
                code = key.keycode
                if isinstance(code, (list, tuple)):
                    code = code[0]
                if key.keystate != 1:  # only key-down events
                    continue
                try:
                    on_press(code)
                except Exception as cb_err:  # pragma: no cover - defensive
                    logger.warning("Pedal callback error: %s", cb_err)
        except (FileNotFoundError, PermissionError):
            pass
        except Exception as e:
            logger.warning("Pedal error: %s", e)

    thread = threading.Thread(target=pedal_reader, daemon=True, name="PedalListener")
    thread.start()
    return thread
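Since key mapping lives in the caller (per the module docstring), usage might look like the following sketch; the mapping and action names here are illustrative, not part of the new module:

```python
# Map pedal key codes to caller-defined actions (illustrative names).
PEDAL_ACTIONS = {"KEY_A": "start_recording", "KEY_B": "stop_recording"}
events: list[str] = []

def on_press(code: str) -> None:
    # Runs in the listener thread, so keep it small and thread-safe.
    action = PEDAL_ACTIONS.get(code)
    if action is not None:
        events.append(action)

# thread = start_pedal_listener(on_press)  # returns None when evdev is missing
on_press("KEY_A")  # simulate one key-down event for illustration
```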
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    )

    batch = next(iter(dataloader))
    for key in batch:
        if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
            batch[key] = batch[key].to(dtype=torch.float32) / 255.0
    batch = preprocessor(batch)
    loss, output_dict = policy.forward(batch)

@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    # indicating padding (those ending with "_is_pad")
    dataset.reader.delta_indices = None
    batch = next(iter(dataloader))
    for key in batch:
        if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
            batch[key] = batch[key].to(dtype=torch.float32) / 255.0
    obs = {}
    for k in batch:
        # TODO: regenerate the safetensors

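The uint8-to-float conversion repeated in both hunks can be sketched standalone. Here numpy stands in for torch to keep the example light (an assumption for illustration only):

```python
import numpy as np

def normalize_uint8(batch: dict) -> dict:
    # Convert uint8 image arrays to float32 in [0, 1]; leave other
    # entries (float arrays, strings, ...) untouched.
    out = {}
    for key, value in batch.items():
        if isinstance(value, np.ndarray) and value.dtype == np.uint8:
            value = value.astype(np.float32) / 255.0
        out[key] = value
    return out

batch = {"pixels": np.array([0, 128, 255], dtype=np.uint8), "task": "pick"}
norm = normalize_uint8(batch)  # norm["pixels"] is float32 with max 1.0
```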
@@ -454,6 +454,35 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
    )


def test_cleanup_interrupted_episode_removes_image_temp_dirs(tmp_path, empty_lerobot_dataset_factory):
    """Verify interrupted episode cleanup removes temporary image directories for both image and video features."""
    features = {
        "image": {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
        "video": {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
    }
    ds = empty_lerobot_dataset_factory(
        root=tmp_path / "interrupted", features=features, streaming_encoding=False
    )
    # Add one frame without saving the episode, simulating an interruption.
    ds.add_frame(
        {
            "image": np.random.rand(*DUMMY_CHW),
            "video": np.random.rand(*DUMMY_HWC),
            "task": "Dummy task",
        }
    )
    img_dir = ds.writer._get_image_file_dir(0, "image")
    vid_img_dir = ds.writer._get_image_file_dir(0, "video")
    # Precondition: both temp dirs exist after add_frame.
    assert img_dir.exists()
    assert vid_img_dir.exists()

    ds.writer.cleanup_interrupted_episode(episode_index=0)

    assert not img_dir.exists(), "image temp dir leaked after cleanup_interrupted_episode"
    assert not vid_img_dir.exists(), "video temp dir leaked after cleanup_interrupted_episode"


def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
    """Verify temporary image directories are removed appropriately when both image and video features are present."""
    image_key = "image"

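The cleanup this test exercises amounts to deleting the per-episode temporary image trees. A generic sketch with `shutil` follows; the paths and the function signature are illustrative, not the dataset writer's actual API:

```python
import shutil
import tempfile
from pathlib import Path

def cleanup_interrupted_episode(episode_dirs: list[Path]) -> None:
    # Remove every temp image dir left behind by an unsaved episode.
    for d in episode_dirs:
        if d.exists():
            shutil.rmtree(d)

root = Path(tempfile.mkdtemp())
img_dir = root / "images" / "episode_000000" / "image"
vid_dir = root / "images" / "episode_000000" / "video"
for d in (img_dir, vid_dir):
    d.mkdir(parents=True)

cleanup_interrupted_episode([img_dir, vid_dir])  # both temp dirs are gone
```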
@@ -17,9 +17,9 @@
import pytest
import torch

from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.action_interpolator import ActionInterpolator

# ====================== Fixtures ======================


@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):

    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            if batch[key].dtype == torch.uint8:
                batch[key] = batch[key].to(dtype=torch.float32) / 255.0
            batch[key] = batch[key].to(DEVICE, non_blocking=True)

    # Test updating the policy (and test that it does not mutate the batch)

@@ -18,6 +18,11 @@ from unittest.mock import MagicMock, patch

import pytest

from lerobot.utils.import_utils import is_package_available

if not is_package_available("reachy2_sdk"):
    pytest.skip("reachy2_sdk not available", allow_module_level=True)

from lerobot.teleoperators.reachy2_teleoperator import (
    REACHY2_ANTENNAS_JOINTS,
    REACHY2_L_ARM_JOINTS,

@@ -24,10 +24,6 @@ def lerobot_train(args):
    return run_command(cmd="lerobot-train", module="lerobot_train", args=args)


def lerobot_record(args):
    return run_command(cmd="lerobot-record", module="lerobot_record", args=args)


def resolve_model_id_for_peft_training(policy_type):
    """PEFT training needs pretrained models; this finds the pretrained model of a policy type for PEFT training."""
    if policy_type == "smolvla":
@@ -155,81 +151,3 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
            f"--output_dir={output_dir}",
        ]
    )


class DummyRobot:
    name = "dummy"
    cameras = []
    action_features = {"foo": 1.0, "bar": 2.0}
    observation_features = {"obs1": 1.0, "obs2": 2.0}
    is_connected = True

    def connect(self, *args):
        pass

    def disconnect(self):
        pass


def dummy_make_robot_from_config(*args, **kwargs):
    return DummyRobot()


@pytest.mark.parametrize("policy_type", ["smolvla"])
@skip_if_package_missing("peft")
def test_peft_record_loads_policy(policy_type, tmp_path):
    """Train a policy with PEFT and attempt to load it with `lerobot-record`."""
    from peft import PeftModel

    output_dir = tmp_path / f"output_{policy_type}"
    model_id = resolve_model_id_for_peft_training(policy_type)

    lerobot_train(
        [
            f"--policy.path={model_id}",
            "--policy.push_to_hub=false",
            "--policy.input_features=null",
            "--policy.output_features=null",
            "--peft.method=LORA",
            "--dataset.repo_id=lerobot/pusht",
            "--dataset.episodes=[0, 1]",
            "--steps=1",
            f"--output_dir={output_dir}",
        ]
    )

    policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
    dataset_dir = tmp_path / "eval_pusht"
    single_task = "move the table"
    loaded_policy = None

    def dummy_record_loop(*args, **kwargs):
        nonlocal loaded_policy

        if "dataset" not in kwargs:
            return

        dataset = kwargs["dataset"]
        dataset.add_frame({"task": single_task})
        loaded_policy = kwargs["policy"]

    with (
        patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
        # disable record loop since we're only interested in successful loading of the policy.
        patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
        # disable speech output
        patch("lerobot.utils.utils.say"),
    ):
        lerobot_record(
            [
                f"--policy.path={policy_dir}",
                "--robot.type=so101_follower",
                "--robot.port=/dev/null",
                "--dataset.repo_id=lerobot/eval_pusht",
                f'--dataset.single_task="{single_task}"',
                f"--dataset.root={dataset_dir}",
                "--dataset.push_to_hub=false",
            ]
        )

    assert isinstance(loaded_policy, PeftModel)

@@ -21,8 +21,9 @@ import pytest

pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")

from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
from lerobot.scripts.lerobot_record import RecordConfig, record
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
from tests.fixtures.constants import DUMMY_REPO_ID