Compare commits

..

47 Commits

Author SHA1 Message Date
CarolinePascal 4445849b86 feat(depth maps writer): adding support for raw depth maps recording with image writer 2026-05-01 00:49:09 +02:00
CarolinePascal f43bf75f9b fix(viz): anchor rerun DepthImage colormap to encoder depth range 2026-05-01 00:49:09 +02:00
CarolinePascal b540fa94a9 feat(viz): render depth observations as rr.DepthImage in Viridis
log_rerun_data now accepts an optional `features` dict and uses the
`video.is_depth_map=True` info marker to detect depth observations.
Matching 2D arrays are logged as `rr.DepthImage(arr, meter=1.0,
colormap=rr.components.Colormap.Viridis)` and are never JPEG-compressed
(compression is lossy on float32 metric depth).

Detection covers both the namespaced dataset key
(e.g. `observation.depth.front`) and the raw observation keys the robot
emits (`front`, `front_depth`), so it works for both the typed
LeRobotDataset.features dict and the plain robot observation flow.

When `features` is None the previous behaviour is preserved (depth
arrays fall back to the generic `rr.Image` path), so non-depth
recordings and existing call sites are unaffected.

lerobot-record now forwards `dataset.features` so depth keys are picked
up automatically when `--display_data=true`.

Made-with: Cursor
2026-05-01 00:49:09 +02:00
CarolinePascal efad15f600 feat(record): plumb DepthEncoderConfig through lerobot-record
Surface DepthEncoderConfig and depth_encoder_defaults from
lerobot.datasets, and wire dataset.depth_encoder_config through
LeRobotDataset.create() and LeRobotDataset.resume() so depth-capable
recordings (e.g. RealSense use_depth=True) can be tuned from the CLI:

    --dataset.depth_encoder_config.depth_min=0.1
    --dataset.depth_encoder_config.depth_max=4.0
    --dataset.depth_encoder_config.vcodec=ffv1

The default factory keeps depth-stream defaults (12-bit HEVC, log
quantization), so non-depth recordings are unaffected.

Made-with: Cursor
2026-05-01 00:49:09 +02:00
CarolinePascal 407d1882a2 feat(robots/so_follower): emit + populate depth keys when use_depth
When an SO follower has a camera configured with use_depth=True (e.g.
a RealSense), the robot now exposes a paired depth feature so the
dataset records both modalities:

- _cameras_ft adds a 2D "<cam>_depth" entry alongside the 3-channel
  color shape; hw_to_dataset_features turns this into
  observation.depth.<cam> with the depth-map marker.
- get_observation reads cam.read_latest_depth() (float32 metric
  meters from the RealSense async depth API) into <cam>_depth so
  build_dataset_frame can route it.

Detection is duck-typed via getattr(..., "use_depth", False) so other
cameras without that attribute keep their RGB-only behaviour unchanged.

Made-with: Cursor
2026-05-01 00:49:09 +02:00
CarolinePascal 0d6e4f3bad feat(features): route 2D camera shapes to observation.depth.<key>
hw_to_dataset_features now treats a camera entry whose shape has
length 2 as a single-channel depth feature: it emits the feature as
"{prefix}.depth.<bare>" with names=["height", "width"] and an
info={"video.is_depth_map": True} marker so the depth-encoder branch
in LeRobotDataset is engaged. The "_depth" hardware-side suffix (if
present) is stripped so a paired RGB + depth camera ends up as
"observation.images.<cam>" + "observation.depth.<cam>".

build_dataset_frame mirrors the routing: depth feature keys read
their value from "<bare>_depth" in the raw observation dict, with
fallback to the bare name for producers that already emit
dataset-style keys.

Tests: add tests/utils/test_feature_utils.py covering the routing
of 2D vs 3D camera shapes, the paired RGB+depth case, and the
build_dataset_frame value routing.

Made-with: Cursor
2026-05-01 00:49:09 +02:00
CarolinePascal 536b29d963 feat(cameras/realsense): expose async depth in metric meters 2026-05-01 00:48:40 +02:00
CarolinePascal 2744e26593 feat(depth): wire DatasetReader to decode_depth_frames 2026-05-01 00:41:38 +02:00
CarolinePascal de64ad3f7e feat(depth): wire StreamingVideoEncoder + writer to depth encoder 2026-05-01 00:29:34 +02:00
CarolinePascal d777359662 feat(depth): plumb DepthEncoderConfig through LeRobotDataset and DatasetWriter 2026-04-30 23:55:28 +02:00
CarolinePascal 5d0a20bd9c feat(video): alias "av1" to "libsvtav1" for backward compat 2026-04-30 23:43:02 +02:00
CarolinePascal 2c796d3352 feat(depth): persist depth metadata + add reader helpers 2026-04-30 23:38:56 +02:00
CarolinePascal df1648c102 feat(video): add ffv1 to supported codecs 2026-04-30 17:32:50 +02:00
CarolinePascal 3bd96a4346 feat(depth): add depth quantization helpers and tests 2026-04-30 17:31:03 +02:00
CarolinePascal 016799dfa1 chore(format): formatting code 2026-04-30 14:42:37 +02:00
CarolinePascal 51b9038458 chore(PyAV): cleaning up PyAV utils and encoding parameters checks to stick to the minimun required tooling. 2026-04-30 14:31:08 +02:00
CarolinePascal cc9a2e5c99 chore(format): fixing formatting issues 2026-04-29 16:48:57 +02:00
CarolinePascal a2376389f9 test(new): adding new tests for encoding related features 2026-04-29 16:48:56 +02:00
CarolinePascal 57a619ab02 test(existing): adapting existing tests 2026-04-29 16:48:56 +02:00
CarolinePascal 7f624adcc5 chore(duplicate): removing duplicate get_codec_options definition 2026-04-29 16:48:56 +02:00
CarolinePascal 375cf1fdf3 feat(pyav checks): making pyav parameters checks more robust 2026-04-29 16:48:56 +02:00
CarolinePascal b2c2bb7641 feat(VideoEncoderConfig init): making VideoEncoderConfig more robust and adaptable to multiple backends 2026-04-29 16:48:56 +02:00
CarolinePascal 4a87ee1537 fix(concatenation compatibility): adding compatibility check when concatenating video files 2026-04-29 16:48:56 +02:00
CarolinePascal e44f86e516 feat(metadata): adding encoding parameters in dataset metadata 2026-04-29 16:48:56 +02:00
CarolinePascal a0e3acdb67 chore(docs): updating the docs 2026-04-29 16:46:16 +02:00
CarolinePascal 38ff579bcc feat(VideoEncoderConfig): propagating the VideoEncoderConfig in the codebase 2026-04-29 16:44:47 +02:00
CarolinePascal 479e444517 feat(VideoEncoderConfig): creating a VideoEncoderConfig to encapsulate encoding parameters 2026-04-29 16:42:14 +02:00
CarolinePascal 9787b8fa26 feat(pyav utils): adding suport for PyAV encoding parameters validation 2026-04-29 16:42:14 +02:00
CarolinePascal 71f39f6912 chore(video backend): renaming codec into video_backend in get_safe_default_video_backend() 2026-04-29 16:42:14 +02:00
Khalil Meftah b5f65e5332 Expose sarm package API and ship reward model card template (#3477)
* chore: List lerobot_rewardmodel_modelcard_template.md in MANIFEST.in

* chore: export SARMConfig, SARMRewardModel, and make_sarm_pre_post_processors from rewards.sarm.
2026-04-29 16:17:16 +02:00
Khalil Meftah cd6b43ea7a fix(train): migrate legacy RA-BC fields in train config loading (#3480) 2026-04-29 16:17:00 +02:00
Steven Palma 2236bbe7a3 fix(rollout): propagate policy-specific CLI config paramaters (#3483)
Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
2026-04-29 16:13:10 +02:00
Maxime Ellerbach cb0a944941 refactor(datasets): replace untyped dict with typed DatasetInfo dataclass (#3472)
* refactor(datasets): replace untyped dict with typed DatasetInfo dataclass

Introduce typed DatasetInfo dataclass to replace untyped dict representation of info.json.

Changes:
- Add DatasetInfo dataclass with explicit fields and validation
- Implement __post_init__ for shape conversion (list ↔ tuple)
- Add dict-style compatibility layer (__getitem__, __setitem__, .get())
- Add from_dict() and to_dict() for JSON serialization
- Update io_utils to use load_info/write_info with DatasetInfo
- Update dataset utilities and metadata to use attribute access
- Remove aggregate.py dict-style field access
- Add tests fixture support for DatasetInfo

Benefits:
- Type safety with IDE auto-completion
- Validation at construction time
- Explicit schema documentation

* fix pre-commit

* update docstring inside DatasetInfo.from_dict()

* sorts the unknown to have deterministic output

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* refactoring the last few old fieds


* fix crop dataset roi type mismatch


* use consistantly int for data and video_files_size_in_mb

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: jjolla93 <jjolla93@gmail.com>
2026-04-28 18:40:30 +02:00
Khalil Meftah 8a3d64033f Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* linter + missing files

* add testing for sampl weighter

* revert some useless changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2026-04-28 17:56:24 +02:00
Steven Palma 03ee50e08f chore(ci): bump docs workflows (#3476) 2026-04-28 15:06:44 +02:00
Steven Palma ca87ccd941 feat(rollout): decouple policy deployment from data recording with new lerobot-rollout CLI (#3413)
* feat(scripts): lerobot-rollout

* fix(rollout) require dataset in dagger + use duration too

* fix(docs): dagger num_episodes

* test(rollout): fix expectations

* fix(rollout): features check

* fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config

* docs(rollout): edit rename_map instructions

* chore(rollout): multiple minor improvements

* chore(rollout): address coments + minor improvements

* fix(rollout): enable default

* fix(tests): default value RTCConfig

* fix(rollout): robot_observation_processor and notify_observation at policy frequency instead of interpolator rate

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): prevent relativeactions with sync inference engine

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): rtc reanchor to non normalized state

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(rollout): fixing the episode length to use hwc (#3469)

also reducing default length to 5 minutes

* feat(rollout): go back to initial position is now a config

* fix(rollout): properly propagating video_files_size_in_mb to lerobot_dataset (#3470)

* chore(rollout): note about dagger correction stage

* chore(docs): update comments and docstring

* fix(test): move rtc relative out of rollout module

* fix(rollout): address the review comments

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime.ellerbach@huggingface.co>
2026-04-28 00:57:35 +02:00
Steven Palma 77352c495c chore(dependencies): update uv.lock (#3437)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-04-27 23:15:46 +02:00
Steven Palma 05a5223885 fix(pi): avoid peak RAM in PiGemma construction by freeing replaced submodules (#3454)
Co-Authored-By: Daiki Kamata <daiki.kamata@access-company.com>
Co-Authored-By: Jack Vial <jackvial@users.noreply.github.com>
Co-Authored-By: Ajay Anubolu <AjAnubolu@users.noreply.github.com>
Co-Authored-By: Finn F. <F-Fer@users.noreply.github.com>
2026-04-24 17:50:12 +02:00
Steven Palma 580d818aa9 fix(dataset): no default overwrite in lerobot tool recompute stats (#3452) 2026-04-24 15:07:19 +02:00
Steven Palma 587aa82021 fix(imports): realsense import name is platform dependent (#3451) 2026-04-24 12:55:38 +02:00
Chuyao Shen 12b88fce02 not use dataclass (#3414)
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-04-24 11:26:59 +02:00
masato-ka fc6c94c82a fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in… (#3419)
* fix(sarm): handle BaseModelOutputWithPooling from transformers 5.x in CLIP encoding

In transformers 5.x, CLIPModel.get_image_features() and get_text_features()
return BaseModelOutputWithPooling instead of a plain torch.FloatTensor.
Added isinstance check to extract pooler_output when the return value is not
a tensor, maintaining backward compatibility with transformers 4.x.

Fixes AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'detach'

* Adding assertion check for pooler_output of CLIP. This change is response to below comment.
https://github.com/huggingface/lerobot/pull/3419#discussion_r3112594387

* Adding assertion check for pooler_output of CLIP. This change is response to below comment. Change to simple check and rise
https://github.com/huggingface/lerobot/pull/3419#discussion_r3126953776

---------
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-04-23 16:26:58 +02:00
Steven Palma 1add460678 fix(policy): loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT (#3442)
* Fix loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT

When action_is_pad masks out padded timesteps, the subsequent .mean()
still divides by the total element count (including zeroed-out padding),
underestimating the loss. With 60-70% padding this can cut the effective
gradient signal by 2-3x.

Replace mask-then-mean with mask-then-sum / valid-count for all three
affected policies. TDMPC is not affected because it sums over time
before averaging over batch.

Fixes #3353

* linting

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update src/lerobot/policies/diffusion/modeling_diffusion.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* apply ACT loss normalization suggestion from review

Divide by num_valid (timesteps * action_dim) instead of just timesteps,
matching the diffusion/multi_task_dit fix. Addresses review from
@whats2000 (https://github.com/huggingface/lerobot/pull/3377#discussion_r3106845791).

* fix(test): update safetensor act

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
2026-04-23 15:23:54 +02:00
Qi Jia 4587c2b648 fix xvla docs (#3291)
Co-authored-by: Qi Jia <kaufou@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-04-23 14:50:32 +02:00
whats2000 2236cdb302 fix(smolvla): correct loss normalization for padded actions (#3434)
Apply the same per-scalar-mean fix to SmolVLA that #3377 landed for
ACT / Diffusion / MultiTaskDiT. The pre-patch form applies the
`action_is_pad` mask to zero out padded timesteps, then calls `.mean()`
(or `.mean(dim=(1, 2))`). Because `.mean()` divides by the total number
of elements including the zeroed padding, the loss is diluted by the
padding fraction.

Fixed by normalizing only over valid (non-padded) scalar entries:

    num_valid = ((~actions_is_pad).sum(...) * losses.shape[-1]).clamp_min(1)
    loss = losses.sum(...) / num_valid

`clamp_min(1)` preserves the all-padded-batch edge case (0/1 = 0). Both
reduction paths are updated. Behavior when `action_is_pad` is missing is
unchanged (`losses.mean()`).

Empirical A/B on aloha_sim_transfer_cube_human (chunk_size=40, batch=2,
30 steps, fixed seed, GB200) shows `loss_A / loss_B = 0.9672 (±0.088)` —
same direction and magnitude as PR #3377's `loss_A / loss_C ≈ 0.96` for
ACT. Heavier-padding recipes will see a larger gap.

Refs: #3353 (original report for ACT), #3377 (fix for the other three
policies).
2026-04-23 10:34:11 +02:00
Steven Palma 7c2466979e chore(dependencies): update uv.lock (#3408)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-04-22 16:38:51 +02:00
Pepijn 39b966e20a docs(agents): add AGENT_GUIDE.md for user facing agent (#3430)
* docs(agents): add AGENT_GUIDE.md with SO-101, data, policy, training, eval guidance

Adds an agent-facing companion to AGENTS.md that helps AI agents (Cursor,
Claude, ChatGPT, etc.) guide end-users through LeRobot without needing to
re-read every doc:

- Mandatory "ask the user first" block (goal, hardware, GPU, skill level)
- SO-101 end-to-end cheat-sheet: install -> calibrate -> record -> train -> eval
- Data-collection tips distilled from the folding project (practice before
  you record, quality > speed, start constrained then add diversity)
- Policy decision table with indicative profiling numbers (update ms, peak
  GPU mem) and AdamW-vs-SGD caveats
- Training duration guidance: 5-10 epoch rule, epoch<->step conversion,
  scheduler/checkpoint scaling with --steps, SmolVLA unfreeze tip
- Real-robot eval via lerobot-record --policy.path and sim eval via
  lerobot-eval, including the pre-baked docker/Dockerfile.benchmark.* images

AGENTS.md gets a short pointer to AGENT_GUIDE.md at the top.
CLAUDE.md (symlink to AGENTS.md) inherits the pointer automatically.

Made-with: Cursor

* docs(agents): recommend 2 cameras (front + wrist) as default

Made-with: Cursor

* docs(agents): add Feetech wiring check and broaden visualizer note

Made-with: Cursor

* docs(agents): clarify Feetech LED behavior (steady-on, not flash)

Made-with: Cursor

* docs(agents): expand Feetech troubleshooting (blinking LED, 5V vs 12V variants)

Made-with: Cursor

* docs(agents): tighten Feetech LED wording

Made-with: Cursor
2026-04-22 11:54:19 +02:00
217 changed files with 12353 additions and 10230 deletions
@@ -33,7 +33,7 @@ jobs:
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
with:
package_name: lerobot
secrets:
+2 -2
View File
@@ -55,7 +55,7 @@ jobs:
github.repository == 'huggingface/lerobot'
permissions:
contents: read
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
with:
commit_sha: ${{ github.sha }}
package: lerobot
@@ -78,7 +78,7 @@ jobs:
permissions:
contents: read
pull-requests: write
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
+2
View File
@@ -1,5 +1,7 @@
This file provides guidance to AI agents when working with code in this repository.
> **User-facing help → [`AGENT_GUIDE.md`](./AGENT_GUIDE.md)** (SO-101 setup, recording, picking a policy, training duration, eval — with copy-pasteable commands).
## Project Overview
LeRobot is a PyTorch-based library for real-world robotics, providing datasets, pretrained policies, and tools for training, evaluation, data collection, and robot control. It integrates with Hugging Face Hub for model/dataset sharing.
+410
View File
@@ -0,0 +1,410 @@
# AGENT_GUIDE.md — LeRobot Helper for AI Agents & Users
This file is a practical, copy-paste-friendly companion for any AI agent (Cursor, Claude, ChatGPT, Codex, etc.) helping a user work with LeRobot. It complements [`AGENTS.md`](./AGENTS.md) (dev/contributor context) with **user-facing guidance**: how to start, what to train, how long, how to record, and how to calibrate an SO-101.
---
## 1. Start here — ask the user first (MANDATORY)
Before suggesting any command, an agent MUST ask the user at least these questions and wait for answers:
1. **What's your goal?** (e.g. "teach my SO-101 to fold a cloth", "train a policy on an existing HF dataset", "contribute a PR", "understand the codebase")
2. **What hardware do you have?**
- Robot: none / SO-100 / SO-101 / Koch / LeKiwi / Reachy / other
- Teleop: leader arm / phone / keyboard / gamepad / none
- Cameras: how many, resolution, fixed or moving?
3. **What machine will you train on?**
- GPU model + VRAM (e.g. "laptop 3060 6 GB", "RTX 4090 24 GB", "A100 80 GB", "CPU only")
- OS: macOS / Linux / Windows
4. **Skill level & time budget?** First time, some ML, experienced? Hours, days, a weekend?
5. **Do you already have a dataset?** Yes (HF repo id?) / no / want to record one
6. **How can I help right now?** (pick one concrete next step)
Only after you have answers, propose a concrete path. If something is ambiguous, ask again rather than guessing. Bias toward **the simplest thing that works** for the user's hardware and goal.
---
## 2. LeRobot in 60 seconds
LeRobot = **datasets + policies + envs + robot control**, unified by a small set of strong abstractions.
- **`LeRobotDataset`** — episode-aware dataset (video or images + actions + state), loadable from the Hub or disk.
- **Policies** (`ACT`, `Diffusion`, `SmolVLA`, `π0`, `π0.5`, `Wall-X`, `X-VLA`, `VQ-BeT`, `TD-MPC`, …) — all inherit `PreTrainedPolicy` and can be pushed/pulled from the Hub.
- **Processors** — small composable transforms between dataset → policy → robot.
- **Envs** (sim) and **Robots** (real) — same action/observation contract so code swaps cleanly.
- **CLI** — `lerobot-record`, `lerobot-train`, `lerobot-eval`, `lerobot-teleoperate`, `lerobot-calibrate`, `lerobot-find-port`, `lerobot-setup-motors`, `lerobot-replay`.
See [`AGENTS.md`](./AGENTS.md) for repo architecture.
---
## 3. Quickstart paths (pick one)
### Path A — "I have an SO-101 and want my first trained policy"
Go to §4 (SO-101 end-to-end), then §5 (data tips), then §6 (pick a policy — likely **ACT**), then §7 (how long), then §8 (eval).
### Path B — "No hardware, I want to train on an existing dataset"
Skip §4. Pick a policy in §6, pick a duration in §7, then run `lerobot-train` per §4.9 with a Hub `--dataset.repo_id` and an `--env.type` for eval. Finish with §8.
### Path C — "I just want to understand the codebase"
Read §2 above, then `AGENTS.md` "Architecture", then open `src/lerobot/policies/act/` and `src/lerobot/datasets/lerobot_dataset.py` as canonical examples.
---
## 4. SO-101 end-to-end cheat-sheet
Full details in [`docs/source/so101.mdx`](./docs/source/so101.mdx) and [`docs/source/il_robots.mdx`](./docs/source/il_robots.mdx). Minimum commands in order. Confirm arms are assembled + powered before issuing.
**4.1 Install**
```bash
pip install 'lerobot[feetech]' # SO-100/SO-101 motor stack
# pip install 'lerobot[all]' # everything
# pip install 'lerobot[aloha,pusht]' # specific features
# pip install 'lerobot[smolvla]' # add SmolVLA deps
git lfs install && git lfs pull
hf auth login # required to push datasets/policies
```
Contributors can alternatively use `uv sync --locked --extra feetech` (see `AGENTS.md`).
**4.2 Find USB ports** — run once per arm, unplug when prompted.
```bash
lerobot-find-port
```
macOS: `/dev/tty.usbmodem...`; Linux: `/dev/ttyACM0` (may need `sudo chmod 666 /dev/ttyACM0`).
**4.3 Setup motor IDs & baudrate** (one-time, per arm)
```bash
lerobot-setup-motors --robot.type=so101_follower --robot.port=<FOLLOWER_PORT>
lerobot-setup-motors --teleop.type=so101_leader --teleop.port=<LEADER_PORT>
```
**4.4 Calibrate** — center all joints, press Enter, sweep each joint through its full range. The `id` is the calibration key — reuse it everywhere.
```bash
lerobot-calibrate --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower
lerobot-calibrate --teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader
```
**4.5 Teleoperate** (sanity check, no recording)
```bash
lerobot-teleoperate \
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
--teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true
```
> **Feetech timeout / comms error on SO-100 / SO-101?** Before touching software, check the **red motor LEDs** on the daisy chain.
>
> - **All steady red, gripper → base chain** → wiring OK.
> - **One or more motors dark / chain stops mid-way** → wiring issue: reseat the 3-pin cables, check the controller-board power supply, and make sure each motor is fully clicked in.
> - **LEDs blinking** → the motor is in an **error state**: usually overload (forcing a joint past its limit) **or wrong power supply voltage**. SO-100 / SO-101 ship in two variants — a **5 V / 7.4 V** build and a **12 V** build — they are NOT interchangeable. Using a 12 V PSU on a 5 V / 7.4 V arm (or vice-versa) will trip this error; confirm your motor variant before powering up.
>
> Most "timeout" errors are physical, not code.
**4.6 Record a dataset** — keys: **→** next, **←** redo, **ESC** finish & upload.
```bash
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
lerobot-record \
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
--teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/my_task \
--dataset.single_task="<describe the task in one sentence>" \
--dataset.num_episodes=50 \
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10 \
--display_data=true
```
**4.7 Visualize****always** do this before training. Look for missing frames, camera blur, unreachable targets, inconsistent object positions.
After upload: https://huggingface.co/spaces/lerobot/visualize_dataset → paste `${HF_USER}/my_task`. Works for **any LeRobot-formatted Hub dataset** — use it to scout other datasets, inspect episode quality, or debug your own data before retraining.
**4.8 Replay an episode** (sanity check)
```bash
lerobot-replay --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
--dataset.repo_id=${HF_USER}/my_task --dataset.episode=0
```
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. See §6/§7 for policy and duration.
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/my_task \
--policy.type=act \
--policy.device=cuda \
--output_dir=outputs/train/act_my_task \
--job_name=act_my_task \
--batch_size=8 \
--wandb.enable=true \
--policy.repo_id=${HF_USER}/act_my_task
```
**4.10 Evaluate on the real robot** — compare success rate to a teleoperated baseline.
```bash
lerobot-record \
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/eval_my_task \
--dataset.single_task="<same task description as training>" \
--dataset.num_episodes=10 \
--policy.path=${HF_USER}/act_my_task
```
---
## 5. Data collection tips (beginner → reliable policy)
Good data beats clever models. Adopt these defaults and deviate only with evidence.
### 5.1 Setup & ergonomics
- **Fix the rig and cameras** before touching the software. If the rig vibrates or the operator gets frustrated, fix that first — more bad data won't help.
- **Lighting matters more than resolution.** Diffuse, consistent light. Avoid moving shadows.
- **"Can you do the task from the camera view alone?"** If no, your cameras are wrong. Fix before recording.
- Enable **action interpolation** for rollouts when available for smoother trajectories.
### 5.2 Practice before you record
- Do 510 demos without recording. Build a deliberate, repeatable strategy.
- Hesitant or inconsistent demos teach the model hesitation.
### 5.3 Quality over speed
Deliberate, high-quality execution beats fast sloppy runs. Optimize for speed only **after** strategy is dialed in — never trade quality for it.
### 5.4 Consistency within and across episodes
Same grasp, approach vector, and timing. Coherent strategies are much easier to learn than wildly varying movements.
### 5.5 Start small, then extend (the golden rule)
- **First 50 episodes = constrained version** of the task: one object, fixed position, fixed camera setup, one operator.
- Train a quick ACT model. See what fails.
- **Then add diversity** along one axis at a time: more positions → more lighting → more objects → more operators.
- Don't try to collect the "perfect dataset" on day one. Iterate.
### 5.6 Policy choice for beginners
- **Laptop / first time / want results fast → ACT.** Works surprisingly well, trains fast even on a laptop GPU.
- **Bigger GPU / language-conditioned / multi-task → SmolVLA.** Unfreezing the vision encoder (see §7) is a big win here.
- Defer π0 / π0.5 / Wall-X / X-VLA until you have a proven ACT baseline and a 20+ GB GPU.
### 5.7 Recommended defaults for your first task
| Setting | Value |
| ---------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
| Episodes | **50** to start, scale to 100300 after first training |
| Episode length | 2045 s (shorter is fine for grasp/place) |
| Reset time | 10 s |
| FPS | 30 |
| Cameras | **2 cameras recommended**: 1 fixed front + 1 wrist. Multi-view often outperforms single-view. A single fixed camera also works to keep things simple. |
| Task description | Short, specific, action-phrased sentence |
### 5.8 Troubleshooting signal
- Policy fails at one specific stage → record 1020 more episodes **targeting that stage**.
- Policy flaps / oscillates → likely inconsistent demos, or need more training; re-record worst episodes (use **←** to redo).
- Policy ignores the object → camera framing or lighting issue, not a model issue.
See also: [What makes a good dataset](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset).
---
## 6. Which policy should I train?
Match the policy to the user's **GPU memory** and **time budget**. Numbers below come from an internal profiling run (one training update per policy). They are **indicative only** — see caveats.
### 6.1 Profiling snapshot (indicative)
All policies typically train for **510 epochs** (see §7).
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
| `diffusion` | 4 | 168.6 | 4.94 | Multi-modal action distributions; needs mid-range GPU. |
| `smolvla` | 1 | 357.8 | 3.93 | Language-conditioned, multi-task, small VLA. **Unfreeze vision encoder for big gains** (see §7). |
| `xvla` | 1 | 731.6 | 15.52 | Large VLA, multi-task. |
| `wall_x` | 1 | 716.5 | 15.95 | Large VLA with world-model objective. |
| `pi0` | 1 | 940.3 | 15.50 | Strong large VLA baseline (Physical Intelligence). |
| `pi05` | 1 | 1055.8 | 16.35 | Newer π policy; similar footprint to `pi0`. |
**Critical caveats:**
- **Optimizer:** measured with **SGD**. LeRobot's default is **AdamW**, which keeps extra optimizer state → **peak memory will be noticeably higher** with the default, especially for `pi0`, `pi05`, `wall_x`, `xvla`.
- **Batch size:** the large policies were profiled at batch 1. In practice use a **larger batch** for stable training (see §7.4). Memory scales roughly linearly with batch.
### 6.2 Decision rules
- **< 8 GB VRAM (laptop, 3060, M-series Mac):** → `act`. Maybe `diffusion` if you have ~68 GB free.
- **1216 GB VRAM (4070/4080, A4000):** → `smolvla` with defaults, or `act`/`diffusion` with larger batch. `pi0`/`pi05`/`wall_x`/`xvla` feasible only with small batch + gradient accumulation.
- **24+ GB VRAM (3090/4090/A5000):** → any policy. Prefer `smolvla` (unfrozen) for multi-task; `act` for single-task grasp-and-place (still often the best ROI). Could experiment with `pi0` or `pi05` or `xvla`
- **80 GB (A100/H100):** → any, with healthy batch. `pi05`, `xvla`, `wall_x` become comfortable.
- **CPU only:** → don't train here. Use Google Colab (see [`docs/source/notebooks.mdx`](./docs/source/notebooks.mdx)) or a rented GPU.
---
## 7. How long should I train?
Robotics imitation learning usually converges in a **few epochs over the dataset**, not hundreds of thousands of raw steps. Think **epochs first**, then translate to steps.
### 7.1 Rule of thumb
- **Typical total: 510 epochs.** Start at 5, eval, then decide if more helps.
- Very small datasets (< 30 episodes) may want slightly more epochs — but first, **collect more data**.
- VLAs with a pretrained vision backbone typically need **fewer** epochs than training from scratch.
### 7.2 Steps ↔ epochs conversion
```
total_frames = sum of frames over all episodes # e.g. 50 eps × 30 fps × 30 s ≈ 45,000
steps_per_epoch = ceil(total_frames / batch_size)
total_steps = epochs × steps_per_epoch
```
Examples for `--batch_size=8`:
| Dataset size | Frames | Steps / epoch | 5 epochs | 10 epochs |
| ----------------------- | ------: | ------------: | -------: | --------: |
| 50 eps × 30 s @ 30 fps | 45,000 | ~5,625 | 28k | 56k |
| 100 eps × 30 s @ 30 fps | 90,000 | ~11,250 | 56k | 113k |
| 300 eps × 30 s @ 30 fps | 270,000 | ~33,750 | 169k | 338k |
Pass the resulting total with `--steps=<N>`; eval at intermediate checkpoints (`outputs/train/.../checkpoints/`).
### 7.3 Per-policy starting points (single-task, ~50 episodes)
| Policy | Batch | Steps (first run) | Notes |
| -------------- | ----: | ----------------: | ----------------------------------------------------------------- |
| `act` | 816 | 30k80k | Usually converges under 50k for single-task. |
| `diffusion` | 816 | 80k150k | Benefits from longer training than ACT. |
| `smolvla` | 48 | 30k80k | Pretrained VLM → converges fast. |
| `pi0` / `pi05` | 14 | 30k80k | Memory-bound; use gradient accumulation for effective batch ≥ 16! |
### 7.4 Batch size guidance
- **Bigger batch is preferable** for stable gradients on teleop data.
- If GPU memory is the bottleneck, use **gradient accumulation** to raise _effective_ batch without raising peak memory.
- Scale **learning rate** gently with batch; most LeRobot defaults work fine for a 24× batch change.
### 7.5 Scale LR schedule & checkpoints with `--steps`
LeRobot's default schedulers (e.g. SmolVLA's cosine decay) use `scheduler_decay_steps=30_000`, which is sized for long training runs. When you shorten training (e.g. 5k10k steps on a small dataset), **scale the scheduler down to match** — otherwise the LR stays near the peak and never decays. Same for checkpoint frequency.
```bash
lerobot-train ... \
--steps=5000 \
--policy.scheduler_decay_steps=5000 \
--save_freq=5000
```
Rule of thumb: set `scheduler_decay_steps ≈ steps`, and `save_freq` to whatever granularity you want for eval (e.g. every 1k5k steps). Match `scheduler_warmup_steps` proportionally if your run is very short.
### 7.6 SmolVLA: unfreeze the vision encoder for real gains
SmolVLA ships with `freeze_vision_encoder=True`. Unfreezing usually **improves performance substantially** on specialized tasks, at the cost of more VRAM and slower steps. Enable with:
```bash
lerobot-train ... --policy.type=smolvla \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false
```
### 7.7 Signals to stop / keep going
- Train loss plateaus → stop, save a Hub checkpoint.
- Train loss still dropping and you're under 10 epochs → keep going.
---
## 8. Evaluation & benchmarks
Two flavors of evaluation:
### 8.1 Real-robot eval (SO-101, etc.)
Reuse `lerobot-record` with `--policy.path` to run the trained policy on-robot and save the run as an eval dataset. Convention: prefix the dataset with `eval_`.
```bash
lerobot-record \
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/eval_my_task \
--dataset.single_task="<same task description used during training>" \
--dataset.num_episodes=10 \
--policy.path=${HF_USER}/act_my_task
```
Report success rate across episodes. Compare to a teleoperated baseline and to an earlier checkpoint to catch regressions.
### 8.2 Sim-benchmark eval
For policies trained on sim datasets (PushT, Aloha, LIBERO, MetaWorld, RoboCasa, …) use `lerobot-eval` against the matching `env.type`:
```bash
lerobot-eval \
--policy.path=${HF_USER}/diffusion_pusht \
--env.type=pusht \
--eval.n_episodes=50 \
--eval.batch_size=10 \
--policy.device=cuda
```
- Use `--policy.path=outputs/train/.../checkpoints/<step>/pretrained_model` for local checkpoints.
- `--eval.n_episodes` should be ≥ 50 for a stable success-rate estimate.
- Available envs live in `src/lerobot/envs/`. See [`docs/source/libero.mdx`](./docs/source/libero.mdx), [`metaworld.mdx`](./docs/source/metaworld.mdx), [`robocasa.mdx`](./docs/source/robocasa.mdx), [`vlabench.mdx`](./docs/source/vlabench.mdx) for specific benchmarks.
- To add a new benchmark, see [`docs/source/adding_benchmarks.mdx`](./docs/source/adding_benchmarks.mdx) and [`envhub.mdx`](./docs/source/envhub.mdx).
### 8.2b Dockerfiles for benchmark eval
Benchmark envs have native dependencies that are painful to install locally. The repo ships **pre-baked Dockerfiles** for each supported benchmark — use these to run `lerobot-eval` in a reproducible environment:
| Benchmark | Dockerfile |
| ----------- | -------------------------------------------------------------------------------------- |
| LIBERO | [`docker/Dockerfile.benchmark.libero`](./docker/Dockerfile.benchmark.libero) |
| LIBERO+ | [`docker/Dockerfile.benchmark.libero_plus`](./docker/Dockerfile.benchmark.libero_plus) |
| MetaWorld | [`docker/Dockerfile.benchmark.metaworld`](./docker/Dockerfile.benchmark.metaworld) |
| RoboCasa | [`docker/Dockerfile.benchmark.robocasa`](./docker/Dockerfile.benchmark.robocasa) |
| RoboCerebra | [`docker/Dockerfile.benchmark.robocerebra`](./docker/Dockerfile.benchmark.robocerebra) |
| RoboMME | [`docker/Dockerfile.benchmark.robomme`](./docker/Dockerfile.benchmark.robomme) |
| RoboTwin | [`docker/Dockerfile.benchmark.robotwin`](./docker/Dockerfile.benchmark.robotwin) |
| VLABench | [`docker/Dockerfile.benchmark.vlabench`](./docker/Dockerfile.benchmark.vlabench) |
Build and run (adapt to your benchmark):
```bash
docker build -f docker/Dockerfile.benchmark.robomme -t lerobot-bench-robomme .
docker run --gpus all --rm -it \
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
lerobot-bench-robomme \
lerobot-eval --policy.path=<your_policy> --env.type=<env> --eval.n_episodes=50
```
See [`docker/README.md`](./docker/README.md) for base-image details.
### 8.3 Target success rates
Single-task grasp-and-place with 50 clean episodes: ACT should reach **> 70% success** on the training configuration. Less → data problem (see §5), not model problem. Expect a drop when generalizing to new positions — scale episodes or diversity to recover.
---
## 9. Further reading & resources
- **Getting started:** [`installation.mdx`](./docs/source/installation.mdx) · [`il_robots.mdx`](./docs/source/il_robots.mdx) · [What makes a good dataset](https://huggingface.co/blog/lerobot-datasets)
- **Per-policy docs:** browse [`docs/source/*.mdx`](./docs/source/) (policies, hardware, benchmarks, advanced training).
- **Community:** [Discord](https://discord.com/invite/s3KuuzsPFb) · [Hub `LeRobot` tag](https://huggingface.co/datasets?other=LeRobot) · [Dataset visualizer](https://huggingface.co/spaces/lerobot/visualize_dataset)
> Keep this file current. If you learn a rule that would prevent a class of user mistakes, add it here and in [`AGENTS.md`](./AGENTS.md).
+1
View File
@@ -1,3 +1,4 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
-6
View File
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
--env.episode_length=5 \
--eval.n_episodes=1 \
--eval.batch_size=1
# E2E annotation pipeline smoke test against a tiny in-memory fixture
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
# backend, so it does not require a real model checkpoint or GPU.
annotation-e2e:
uv run python -m tests.annotations.run_e2e_smoke
+8 -4
View File
@@ -39,6 +39,7 @@ from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
VideoEncoderConfig,
decode_video_frames,
encode_video_frames,
)
@@ -251,10 +252,13 @@ def benchmark_encoding_decoding(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=encoding_cfg["vcodec"],
pix_fmt=encoding_cfg["pix_fmt"],
g=encoding_cfg.get("g"),
crf=encoding_cfg.get("crf"),
camera_encoder_config=VideoEncoderConfig(
vcodec=encoding_cfg["vcodec"],
pix_fmt=encoding_cfg["pix_fmt"],
g=encoding_cfg.get("g"),
crf=encoding_cfg.get("crf"),
preset=encoding_cfg.get("preset"),
),
# fast_decode=encoding_cfg.get("fastdecode"),
overwrite=True,
)
+4 -4
View File
@@ -31,10 +31,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: language_and_recipes
title: Language Columns and Recipes
- local: annotation_pipeline
title: Annotation Pipeline
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
@@ -63,6 +61,8 @@
title: SARM
title: "Reward Models"
- sections:
- local: inference
title: Policy Deployment (lerobot-rollout)
- local: async
title: Use Async Inference
- local: rtc
+1 -1
View File
@@ -90,6 +90,6 @@ lerobot-record \
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--policy.path=${HF_USER}/act_policy
```
-149
View File
@@ -1,149 +0,0 @@
# Annotation Pipeline
`lerobot-annotate` populates the two language columns introduced by the
[Language Columns and Recipes](./language_and_recipes) page —
`language_persistent` and `language_events` — directly into
`data/chunk-*/file-*.parquet`. There is no flavor namespace and no sidecar
file tree: multiple revisions of a dataset mean multiple dataset copies.
## What the pipeline produces
Three modules write into a per-episode staging tree, then a single writer
rewrites the data shards in place:
| Style / atom | Column | Module |
| ------------------------------------------- | --------------------- | -------- |
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | Module 1 |
| `plan` (initial + refresh on interjection) | `language_persistent` | Module 1 |
| `memory` (MEM-style compression) | `language_persistent` | Module 1 |
| `interjection` | `language_events` | Module 2 |
| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 |
| `vqa` (user / assistant pair) | `language_events` | Module 3 |
The writer also adds a dataset-level `tools` column carrying the JSON schema
for the `say` tool call, and drops the legacy `subtask_index` column.
## How to run it locally or on SLURM
Install the extra and invoke the console script:
```bash
uv sync --extra annotations
uv run lerobot-annotate \
--repo_id=imstevenpmwork/super_poulain_draft \
--vlm.backend=vllm \
--vlm.model_id=Qwen/Qwen3.6-27B-FP8 \
--vlm.tensor_parallel_size=2
```
The pipeline attaches actual camera footage to every Module 1/2/3 prompt
by default, decoded from the dataset's first `observation.images.*`
stream. Override with `--vlm.camera_key=observation.images.<name>` to
pin a specific viewpoint. Datasets with no video tracks fall back to
text-only prompts automatically.
**Module 1 sees the whole episode as one video block.** Subtask
decomposition gets a `{"type":"video", "video":[<frames>]}` block
covering the entire demonstration; Qwen-VL pools temporally on its own
and decides where to cut. There is no keyframe stride or count knob —
`--module_1.max_video_frames` (default 32) only caps the frames packed
into the video block as a model-capacity bound. Module 2 attaches a
single still frame at the interjection timestamp; Module 3 attaches the
exact emission frame to each VQA pair.
The executor picks `LocalPipelineExecutor` for small datasets and
`SlurmPipelineExecutor` for large ones based on
`--executor.auto_threshold` (default 32 episodes). Force local with
`--executor.force_local=true`. SLURM jobs honour `--executor.slurm_partition`,
`--executor.slurm_gpus`, and `--executor.slurm_time`.
## Style-to-recipe consumer mapping
The pipeline produces exactly the styles consumed by
`src/lerobot/configs/recipes/pi05_hirobot.yaml`:
- `low_level_execution`, `high_level_subtask`, `memory_update` consume
`subtask`/`plan`/`memory` from `language_persistent`.
- `user_interjection_response` consumes `interjection` events plus the
paired speech atom (merged into one assistant target turn via
`tool_calls_from`) and the same-timestamp `plan` refresh.
- `ask_vqa` consumes the `(vqa, user)` and `(vqa, assistant)` pairs from
`language_events`.
## Why the design is scoped to the canonical recipe
Two things drive the scope:
1. **Persistent state vs exact-event split.** Persistent rows (`subtask`,
`plan`, `memory`) broadcast per episode and answer "what state is in
force at this frame?". Event rows (`interjection`, `vqa`, speech) only
appear on the exact frame whose timestamp matches the emission. The
pipeline writes timestamps taken straight from the source parquet — no
floating-point recomputation.
2. **One Qwen-VL pass.** All three modules share a single VLM client
(vLLM if available, transformers fallback) so the cost is one model
load per dataset, not three.
## Module independence and staged reruns
Each module writes its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
prompt iteration cheap — re-running one module overwrites only its own
JSONL file before the writer composes the final parquet. Modules can be
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
test them in isolation.
## Validation/report checks before final write
Before the writer runs, `StagingValidator` checks:
- exact frame-timestamp alignment for every event row;
- no orphan speech / interjection pairs;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (warning, not error);
- VQA assistant `content` parses as JSON in one of the
bbox / keypoint / count / attribute / spatial shapes;
- every row routes to the column dictated by `column_for_style(style)`.
Errors abort the writer (`--skip_validation=true` overrides for debugging).
## Paper inspirations per module
- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
atom granularity ("pick up one piece of lettuce", "place bowl to box");
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
what" detail.
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
compression directive: keep only minimal relevant information; functional
outcomes preserved, specific attributes dropped.
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
situated correction, specific constraint, preference. Speech is a
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
arguments:{text:...}}}]`).
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
keypoints) and Steerable Policies' multi-abstraction grounding.
Future maintainers should adjust the prompt templates in
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
references rather than rewriting from scratch.
## Compute and list-size estimates
Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
O(`max_interjections_per_episode`) Module 2 calls, and
O(`vqa_emission_hz × episode_seconds`) Module 3 calls. With defaults
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
KB at most (parquet dictionary-encodes one entry per episode);
`language_events` is empty on most frames and is bounded by the number of
emissions, not `num_frames × num_emissions`.
## Reproducibility via seed and prompt hashes
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
timestamps and VQA question types. Combined with the deterministic prompt
templates checked into `prompts/`, two runs at the same seed against the
same dataset and the same model checkpoint produce byte-identical staging
artifacts. Prompt edits are recorded by file hash; future tooling can pin
expected `(seed, prompt_hash)` pairs into the dataset card.
+277
View File
@@ -0,0 +1,277 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
+1 -1
View File
@@ -194,7 +194,7 @@ lerobot-record \
--dataset.single_task="Navigate around obstacles" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--display_data=true
```
+1 -1
View File
@@ -123,7 +123,7 @@ lerobot-record \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10
+19 -21
View File
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
### Teleoperator Requirements
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
- Enable/disable torque programmatically
- Move to target positions (to mirror the robot state when pausing)
**Compatible teleoperators in the current `examples/hil` scripts:**
**Compatible teleoperators:**
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
## Script
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
| Mode | Flag | Models |
| ------------------------ | -------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
| Mode | Flag | Models |
| ------------------------ | ---------------------- | --------------------- |
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
---
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
**Standard inference (ACT, Diffusion Policy):**
```bash
python examples/hil/hil_data_collection.py \
lerobot-rollout --strategy.type=dagger \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
@@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-dataset \
--dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--strategy.num_episodes=50 \
--interpolation_multiplier=2
```
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
For models with high inference latency, enable RTC for smooth execution:
```bash
python examples/hil/hil_data_collection.py \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
lerobot-rollout --strategy.type=dagger \
--inference.type=rtc \
--inference.rtc.execution_horizon=20 \
--inference.rtc.max_guidance_weight=5.0 \
--inference.rtc.prefix_attention_schedule=LINEAR \
--robot.type=bi_openarm_follower \
--robot.left_arm_config.port=can1 \
--robot.left_arm_config.side=left \
@@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/hil-rtc-dataset \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \
--dataset.fps=30 \
--dataset.episode_time_s=1000 \
--dataset.num_episodes=50 \
--strategy.num_episodes=50 \
--interpolation_multiplier=3
```
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
+2 -2
View File
@@ -232,7 +232,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--display_data=true
```
@@ -278,6 +278,6 @@ lerobot-record \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
```
+27 -106
View File
@@ -193,7 +193,7 @@ lerobot-record \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--dataset.encoder_threads=2
```
</hfoption>
@@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
<hfoptions id="eval">
<hfoption id="Command">
<hfoption id="Base mode (no recording)">
```bash
lerobot-record \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--robot.id=my_awesome_follower_arm \
--display_data=false \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_awesome_leader_arm \
--policy.path=${HF_USER}/my_policy
--task="Put lego brick into the transparent box" \
--duration=60
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.policies.act import ACTPolicy
from lerobot.policies import make_pre_post_processors
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
)
# Initialize the robot
robot = SO100Follower(robot_config)
# Initialize the policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
# Connect the robot
robot.connect()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
)
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Run the policy inference loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
)
dataset.save_episode()
# Clean up
robot.disconnect()
dataset.push_to_hub()
<hfoption id="Sentry mode (with recording)">
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM1 \
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/eval_so100 \
--dataset.single_task="Put lego brick into the transparent box" \
--duration=600
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
The `--strategy.type` flag selects the execution mode:
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+261
View File
@@ -0,0 +1,261 @@
# Policy Deployment (lerobot-rollout)
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
## Quick Start
No extra dependencies are needed beyond your robot and policy extras.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=lerobot/act_koch_real \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--task="pick up cube" \
--duration=30
```
This runs the policy for 30 seconds with no recording.
---
## Strategies
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
### Base (`--strategy.type=base`)
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
```bash
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Put lego brick into the box" \
--duration=60
```
| Flag | Description |
| ---------------- | ------------------------------------------------------ |
| `--duration` | Run time in seconds (0 = infinite) |
| `--task` | Task description passed to the policy |
| `--display_data` | Stream observations/actions to Rerun for visualization |
### Sentry (`--strategy.type=sentry`)
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--dataset.repo_id=${HF_USER}/rollout_eval_data \
--dataset.single_task="Put lego brick into the box" \
--duration=3600
```
| Flag | Description |
| -------------------------------------- | ----------------------------------------------------------- |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
### Highlight (`--strategy.type=highlight`)
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
```bash
lerobot-rollout \
--strategy.type=highlight \
--strategy.ring_buffer_seconds=30 \
--strategy.save_key=s \
--strategy.push_key=h \
--policy.path=${HF_USER}/my_policy \
--robot.type=koch_follower \
--robot.port=/dev/ttyACM0 \
--dataset.repo_id=${HF_USER}/rollout_highlight_data \
--dataset.single_task="Pick up the red cube"
```
**Keyboard controls:**
| Key | Action |
| ------------------ | -------------------------------------------------------- |
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
| `h` (configurable) | Push dataset to Hub |
| `ESC` | Stop the session |
| Flag | Description |
| -------------------------------------- | ---------------------------------------------- |
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
### DAgger (`--strategy.type=dagger`)
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
```bash
lerobot-rollout \
--strategy.type=dagger \
--strategy.record_autonomous=true \
--strategy.num_episodes=50 \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/rollout_dagger_data \
--dataset.single_task="Grasp the block"
```
**Keyboard controls** (default input device):
| Key | Action |
| ------- | ------------------------------------------- |
| `Space` | Pause / resume policy execution |
| `Tab` | Start / stop human correction |
| `Enter` | Push dataset to Hub (corrections-only mode) |
| `ESC` | Stop the session |
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
| Flag | Description |
| ------------------------------------ | ------------------------------------------------------- |
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
---
## Inference Backends
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
### Sync (default)
One policy call per control tick. The main loop blocks until the action is computed.
Works with all policies. No extra flags needed.
### Real-Time Chunking (`--inference.type=rtc`)
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
```bash
lerobot-rollout \
--strategy.type=base \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--policy.path=${HF_USER}/pi0_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Pick up the cube" \
--duration=60 \
--device=cuda
```
| Flag | Description |
| ------------------------------------------- | -------------------------------------------------------------- |
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
---
## Common Flags
| Flag | Description | Default |
| --------------------------------- | ----------------------------------------------------------------- | ------- |
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
| `--robot.port` | Serial port for the robot | -- |
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
| `--fps` | Control loop frequency | 30 |
| `--duration` | Run time in seconds (0 = infinite) | 0 |
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
| `--task` | Task description (used when no dataset is provided) | -- |
| `--display_data` | Stream telemetry to Rerun visualization | false |
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
| `--interpolation_multiplier` | Action interpolation factor | 1 |
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
| `--resume` | Resume a previous recording session | false |
| `--play_sounds` | Vocal synthesis for events | true |
---
## Programmatic Usage
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
```python
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
cfg = RolloutConfig(
robot=my_robot_config,
policy=my_policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=30,
duration=60,
task="my task",
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=my_custom_action_processor, # optional
robot_observation_processor=my_custom_obs_processor, # optional
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
```
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
-75
View File
@@ -1,75 +0,0 @@
# Language columns and recipes
LeRobot stores reusable language annotations directly next to frame data in `data/chunk-*/file-*.parquet`.
The two optional columns are:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape:
```text
role: string
content: string | null
style: string | null
timestamp: float64
tool_calls: list[Json] | null
```
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
## Architecture
The language stack has three layers:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
## Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
## Recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`.
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone.
## Blends
Blend recipes select one weighted sub-recipe deterministically from the sample index.
The canonical `recipes/pi05_hirobot.yaml` combines memory updates, interjection responses, high-level subtask prediction, low-level execution, and VQA.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
+1 -1
View File
@@ -43,7 +43,7 @@ lerobot-record \
--dataset.num_episodes=5 \
--dataset.single_task="Grab the black cube" \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--dataset.encoder_threads=2
```
+2 -2
View File
@@ -161,7 +161,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--display_data=true
```
@@ -203,7 +203,7 @@ lerobot-record \
--dataset.private=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
--display_data=true
```
+7 -18
View File
@@ -61,17 +61,6 @@ lerobot-eval \
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
```
### Recording
`lerobot-record` also supports rename maps, nested under the dataset config:
```bash
lerobot-record \ # When running inference
--policy.path="<user>/smolVLA_finetuned" \
... \
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
```
## Alternative: edit the policy config directly
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
@@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset
## Quick reference
| Goal | What to do |
| ----------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
| Goal | What to do |
| --------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
+7 -3
View File
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
## Testing RTC with a Real Robot
```bash
python examples/rtc/eval_with_real_robot.py \
lerobot-rollout \
--strategy.type=base \
--policy.path=${HF_USERNAME}/policy_repo_id \
--inference.type=rtc \
--inference.rtc.execution_horizon=10 \
--inference.rtc.max_guidance_weight=10.0 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
## References
+29 -28
View File
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
## Inputs and Targets (What the new code expects)
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
<hfoption id="single_stage">
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
<hfoption id="dense_only">
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
<hfoption id="dual">
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--visualize-only \
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
First, run the SARM model on all frames in your dataset to compute progress values:
```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
python -m lerobot.rewards.sarm.compute_rabc_weights \
--dataset-repo-id your-username/your-dataset \
--reward-model-path your-username/sarm-model \
--head-mode sparse \
@@ -465,15 +465,15 @@ This script:
### Step 5b: Train Policy with RA-BC
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
lerobot-train \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
--rabc_head_mode=sparse \
--rabc_kappa=0.01 \
--sample_weighting.type=rabc \
--sample_weighting.head_mode=sparse \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
@@ -488,12 +488,13 @@ The training script automatically:
**RA-BC Arguments:**
| Argument | Description | Default |
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
| Argument | Description | Default |
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
### Tuning RA-BC Kappa
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
Monitor these WandB metrics during training:
| Metric | Healthy Range | Problem Indicator |
| ------------------ | ------------- | ------------------------- |
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `rabc_delta_mean` | > 0 | Should be positive |
| `rabc_delta_std` | > 0 | Variance in data quality |
| Metric | Healthy Range | Problem Indicator |
| ----------------------------- | ------------- | ------------------------- |
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `sample_weighting/delta_mean` | > 0 | Should be positive |
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**Setting kappa based on your data:**
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05]
# Option 1: Set kappa = delta_mean (medium selectivity)
--rabc_kappa=0.03
--sample_weighting.kappa=0.03
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
--rabc_kappa=0.05
--sample_weighting.kappa=0.05
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
--rabc_kappa=0.07
--sample_weighting.kappa=0.07
```
**When RA-BC may not help:**
@@ -550,8 +551,8 @@ accelerate launch \
src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
--rabc_kappa=0.01 \
--sample_weighting.type=rabc \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
@@ -576,7 +577,7 @@ accelerate launch \
### RA-BC
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
---
+1 -1
View File
@@ -108,7 +108,7 @@ lerobot-record \
--dataset.num_episodes=10 \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.vcodec=auto \
# --dataset.camera_encoder_config.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \
+37 -27
View File
@@ -14,12 +14,22 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
## 2. Tuning Parameters
| Parameter | CLI Flag | Type | Default | Description |
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
All encoding parameters are grouped under `camera_encoder_config` (a `VideoEncoderConfig` dataclass), accessible from the CLI via `--dataset.camera_encoder_config.<field>`.
| Parameter | CLI Flag | Type | Default | Description |
| ----------------------- | --------------------------------------------- | ------------- | ------------- | ------------------------------------------------------------------- |
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
| `vcodec` | `--dataset.camera_encoder_config.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
| `pix_fmt` | `--dataset.camera_encoder_config.pix_fmt` | `str` | `"yuv420p"` | Pixel format |
| `g` | `--dataset.camera_encoder_config.g` | `int \| None` | `2` | GOP size (keyframe interval) |
| `crf` | `--dataset.camera_encoder_config.crf` | `int \| None` | `30` | Quality level (mapped to codec-specific parameter) |
| `preset` | `--dataset.camera_encoder_config.preset` | `int \| None` | `12` | Speed preset (libsvtav1 only, 0 = slowest … 13 = fastest) |
| `fast_decode` | `--dataset.camera_encoder_config.fast_decode` | `int` | `0` | Fast-decode tuning level |
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance (global). `None` lets the codec decide |
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
> [!TIP]
> Not all parameters apply to every codec. `VideoEncoderConfig` will warn at startup if you set a parameter that your chosen codec ignores (e.g. `preset` with `h264_nvenc`).
## 3. Performance Considerations
@@ -40,7 +50,7 @@ Streaming encoding means the CPU is encoding video **during** the capture loop,
### `encoder_threads` Tuning
This parameter controls how many threads each encoder instance uses internally:
This parameter (`--dataset.encoder_threads`) controls how many threads each encoder instance uses internally:
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
@@ -82,15 +92,15 @@ Use HW encoding when:
### Available HW Encoders
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
| Encoder | Platform | Hardware | CLI Value |
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------- |
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=h264_videotoolbox` |
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=hevc_videotoolbox` |
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=h264_nvenc` |
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=hevc_nvenc` |
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder_config.vcodec=h264_vaapi` |
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder_config.vcodec=h264_qsv` |
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder_config.vcodec=auto` |
> [!NOTE]
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
@@ -100,15 +110,15 @@ Use HW encoding when:
## 5. Troubleshooting
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
| Symptom | Likely Cause | Fix |
| ------------------------------------------------------------------ | -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder_config.vcodec=auto`) |
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder_config.vcodec=auto`). |
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder_config.vcodec=auto` |
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
## 6. Recommended Configurations
@@ -146,10 +156,10 @@ On very constrained systems, streaming encoding may compete too heavily with the
# 2camsx 640x480x3 @30fps: Requires some tuning.
# Use H.264, disable streaming, consider batching encoding
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
lerobot-record --dataset.camera_encoder_config.vcodec=h264 --dataset.streaming_encoding=false ...
```
## 7. Closing note
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
`camera_encoder_config.vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
+3 -2
View File
@@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \
Once trained, we recommend deploying policies using inference-time RTC:
```bash
python examples/rtc/eval_with_real_robot.py \
lerobot-rollout \
--strategy.type=base \
--policy.path=your-username/your-repo-id \
--policy.device=cuda \
--robot.type=unitree_g1 \
@@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \
--task="task_description" \
--duration=1000 \
--fps=30 \
--rtc.enabled=true
--inference.type=rtc
```
---
+12 -9
View File
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_image_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
--operation.camera_encoder_config.vcodec libsvtav1 \
--operation.camera_encoder_config.pix_fmt yuv420p \
--operation.camera_encoder_config.g 2 \
--operation.camera_encoder_config.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
@@ -147,11 +147,14 @@ lerobot-edit-dataset \
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `camera_encoder_config`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder_config.<field>`:
- `vcodec`: Video codec — `h264`, `hevc`, `libsvtav1`, `auto`, or hardware codecs (default: `libsvtav1`)
- `pix_fmt`: Pixel format — `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: GOP size — lower values give better quality but larger files (default: 2)
- `crf`: Quality level — lower is better, 0 is lossless (default: 30)
- `preset`: Speed preset, libsvtav1 only (default: 12)
- `fast_decode`: Fast-decode tuning (default: 0)
- `encoder_threads`: Threads per encoder instance — global setting, separate from `camera_encoder_config` (default: None)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)
+4 -4
View File
@@ -220,7 +220,7 @@ REAL_DIM = 12
# Postprocessing: Trim 20D predictions to 12D for deployment
```
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
#### Auto Action Mode (Recommended)
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
## Contributing
+1 -1
View File
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
import torch
from tqdm import tqdm
from lerobot.policies.sarm.compute_rabc_weights import (
from lerobot.rewards.sarm.compute_rabc_weights import (
generate_all_frame_indices,
interpolate_progress,
load_sarm_resources,
File diff suppressed because it is too large Load Diff
-226
View File
@@ -1,226 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for Human-in-the-Loop data collection scripts."""
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common.control_utils import is_headless
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
RobotObservation,
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
@dataclass
class HILDatasetConfig:
repo_id: str
single_task: str
root: str | Path | None = None
fps: int = 30
episode_time_s: float = 120
num_episodes: int = 50
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
"""Check if teleoperator has motor control capabilities."""
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
def teleop_disable_torque(teleop: Teleoperator) -> None:
"""Disable teleop torque if supported."""
if hasattr(teleop, "disable_torque"):
teleop.disable_torque()
def teleop_enable_torque(teleop: Teleoperator) -> None:
"""Enable teleop torque if supported."""
if hasattr(teleop, "enable_torque"):
teleop.enable_torque()
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
"""Smoothly move teleop to target position if motor control is available."""
if not teleop_has_motor_control(teleop):
logger.warning("Teleop does not support motor control - cannot mirror robot position")
return
teleop_enable_torque(teleop)
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {}
for k in current:
if k in target_pos:
interp[k] = current[k] * (1 - t) + target_pos[k] * t
else:
interp[k] = current[k]
teleop.write_goal_positions(interp)
time.sleep(1 / fps)
def init_keyboard_listener():
"""Initialize keyboard listener with HIL controls."""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
"policy_paused": False,
"correction_active": False,
"resume_policy": False,
"in_reset": False,
"start_next_episode": False,
}
if is_headless():
logger.warning("Headless environment - keyboard controls unavailable")
return None, events
from pynput import keyboard
def on_press(key):
try:
if events["in_reset"]:
if key in [keyboard.Key.space, keyboard.Key.right]:
logger.info("[HIL] Starting next episode...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "c":
events["start_next_episode"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
events["stop_recording"] = True
events["start_next_episode"] = True
else:
if key == keyboard.Key.space:
if not events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
events["policy_paused"] = True
elif hasattr(key, "char") and key.char == "c":
if events["policy_paused"] and not events["correction_active"]:
logger.info("[HIL] Taking control...")
events["start_next_episode"] = True
elif hasattr(key, "char") and key.char == "p":
if events["policy_paused"] or events["correction_active"]:
logger.info("[HIL] Resuming policy...")
events["resume_policy"] = True
elif key == keyboard.Key.right:
logger.info("[HIL] End episode")
events["exit_early"] = True
elif key == keyboard.Key.left:
logger.info("[HIL] Re-record episode")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
logger.info("[HIL] ESC - Stop recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
logger.info(f"Key error: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def make_identity_processors():
"""Create identity processors for recording."""
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[IdentityProcessorStep()],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[IdentityProcessorStep()],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
return teleop_proc, obs_proc
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
"""Reset period where human repositions environment."""
logger.info("[HIL] RESET")
events["in_reset"] = True
events["start_next_episode"] = False
obs = robot.get_observation()
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
logger.info("Press any key to enable teleoperation")
while not events["start_next_episode"] and not events["stop_recording"]:
precise_sleep(0.05)
if events["stop_recording"]:
return
events["start_next_episode"] = False
teleop_disable_torque(teleop)
logger.info("Teleop enabled - press any key to start episode")
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
action = teleop.get_action()
robot.send_action(action)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
events["start_next_episode"] = False
events["exit_early"] = False
events["policy_paused"] = False
events["correction_active"] = False
events["resume_policy"] = False
def print_controls(rtc: bool = False):
"""Print control instructions."""
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
logger.info(
"%s\n Controls:\n"
" SPACE - Pause policy\n"
" c - Take control\n"
" p - Resume policy after pause/correction\n"
" → - End episode\n"
" ESC - Stop and push to hub",
mode,
)
+62 -31
View File
@@ -14,17 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.common.control_utils import init_keyboard_listener
import logging
import time
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 2
FPS = 30
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
@@ -83,43 +90,67 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot
robot_action_to_send = robot_action_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
log_say("Waiting for environment reset, press right arrow key when ready...")
if events["rerecord_episode"]:
log_say("Re-record episode")
+10 -9
View File
@@ -45,9 +45,6 @@ def main():
leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
@@ -77,6 +74,10 @@ def main():
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
teleop_action_processor, robot_action_processor, robot_observation_processor = (
make_default_processors()
)
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
@@ -87,14 +88,14 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
@@ -106,13 +107,13 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
+77
View File
@@ -0,0 +1,77 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained policy on LeKiwi without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). For a CLI entry point with the same capabilities plus
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
"""
from lerobot.configs import PreTrainedConfig
from lerobot.robots.lekiwi import LeKiwiClientConfig
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
# by ``build_rollout_context`` to reload the full model.
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
# Assemble the rollout config: base strategy (no recording) + sync inference.
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
signal_handler = ProcessSignalHandler(use_threads=True)
# Build the context (connects robot, loads policy, wires the inference strategy).
# No custom processors here — LeKiwi runs on raw joint features.
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
+63 -32
View File
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
RobotProcessorPipeline,
make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 5
FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_joints_to_ee_pose_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
log_say("Waiting for environment reset, press right arrow key when ready...")
if events["rerecord_episode"]:
log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+13 -13
View File
@@ -65,14 +65,15 @@ def main():
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to EE action
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
tuple[RobotAction, RobotObservation], RobotAction
](
@@ -94,7 +95,7 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action
# Build pipeline to convert EE action to joints action (IK).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
@@ -107,7 +108,7 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to EE observation
# Build pipeline to convert joint observation to EE observation (FK).
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
@@ -118,13 +119,12 @@ def main():
to_output=transition_to_observation,
)
# Create the dataset
# Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features),
@@ -163,14 +163,14 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
@@ -182,13 +182,13 @@ def main():
robot=robot,
events=events,
fps=FPS,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
+126
View File
@@ -0,0 +1,126 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
with phone teleoperation in EE space, so at deployment we only need the
joint↔EE conversion on the robot side; the phone is not used.
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
(inline policy call). For recording during rollout, switch to Sentry,
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Peek at motor names once to build the kinematic solver.
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
-673
View File
@@ -1,673 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=<USER>/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot-data-collection/folding_final \
--robot.type=bi_openarm_follower \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
--robot.left_arm_config.port=can0 \
--robot.left_arm_config.side=left \
--robot.left_arm_config.can_interface=socketcan \
--robot.left_arm_config.disable_torque_on_disconnect=true \
--robot.left_arm_config.max_relative_target=8.0 \
--robot.right_arm_config.port=can1 \
--robot.right_arm_config.side=right \
--robot.right_arm_config.can_interface=socketcan \
--robot.right_arm_config.disable_torque_on_disconnect=true \
--robot.right_arm_config.max_relative_target=8.0 \
--task="Fold the T-shirt properly" \
--fps=30 \
--duration=2000 \
--interpolation_multiplier=3 \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR \
--device=cuda
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
NormalizerProcessorStep,
RelativeActionsProcessorStep,
TransitionKey,
create_transition,
make_default_robot_action_processor,
make_default_robot_observation_processor,
to_relative_actions,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_openarm_follower,
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def _reanchor_relative_rtc_prefix(
prev_actions_absolute: Tensor,
current_state: Tensor,
relative_step: RelativeActionsProcessorStep,
normalizer_step: NormalizerProcessorStep | None,
policy_device: torch.device | str,
) -> Tensor:
"""Convert absolute leftovers into model-space for relative-action RTC policies.
When a policy uses relative actions, the RTC prefix (leftover actions from
the previous chunk) is stored in absolute space. Before feeding it back to
the policy we need to re-express it relative to the *current* robot state
and then re-normalize.
"""
state = current_state.detach().cpu()
if state.dim() == 1:
state = state.unsqueeze(0)
action_cpu = prev_actions_absolute.detach().cpu()
mask = relative_step._build_mask(action_cpu.shape[-1])
relative_actions = to_relative_actions(action_cpu, state, mask)
transition = create_transition(action=relative_actions)
if normalizer_step is not None:
transition = normalizer_step(transition)
return transition[TransitionKey.ACTION].to(policy_device)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
# Only keep .pos joints + camera streams if the policy was trained on positions,
# not the full pos/vel/torque state the robot exposes.
observation_features_hw = {
key: value
for key, value in robot.observation_features().items()
if key.endswith(".pos") or isinstance(value, tuple)
}
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
normalizer_step = next(
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
None,
)
if relative_step is not None:
if relative_step.action_names is None:
cfg_names = getattr(cfg.policy, "action_feature_names", None)
if cfg_names:
relative_step.action_names = list(cfg_names)
else:
relative_step.action_names = [
k for k in robot.robot.action_features if k.endswith(".pos")
]
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Re-anchor leftover actions for relative-action policies.
# We need the *postprocessed* (absolute) leftover, not the original
# (normalized/relative) one that get_left_over() returns.
if (
prev_actions is not None
and relative_step is not None
and OBS_STATE in obs_with_policy_features
):
with action_queue.lock:
if action_queue.queue is not None:
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
else:
prev_actions_abs = None
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_actions_abs,
current_state=obs_with_policy_features[OBS_STATE],
relative_step=relative_step,
normalizer_step=normalizer_step,
policy_device=policy_device,
)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
action_count = 0
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
action_interval = interpolator.get_control_interval(cfg.fps)
while not shutdown_event.is_set():
start_time = time.perf_counter()
if interpolator.needs_new_action():
new_action = action_queue.get()
if new_action is not None:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
if config.use_peft:
from peft import PeftConfig, PeftModel
peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")
+63 -32
View File
@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
RobotProcessorPipeline,
make_default_teleop_action_processor,
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
NUM_EPISODES = 5
FPS = 30
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
def main():
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
# This script provides a self-contained example for educational purposes.
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
@@ -143,43 +151,67 @@ def main():
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
control_interval = 1 / FPS
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Inline evaluation loop: predict actions and send to robot
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < EPISODE_TIME_SEC:
start_loop_t = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
# Get robot observation
obs = robot.get_observation()
obs_processed = robot_joints_to_ee_pose_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Predict action using the policy
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=policy.config.device,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.device.type == "cuda",
task=TASK_DESCRIPTION,
robot_type=robot.name,
)
# Convert policy output to robot action dict
action_values = make_robot_action(action_tensor, dataset.features)
# Process and send action to robot (EE -> joints via IK)
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
robot.send_action(robot_action_to_send)
# Write to dataset
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
sleep_time_s = control_interval - dt_s
if sleep_time_s < 0:
logging.warning(
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
)
precise_sleep(max(sleep_time_s, 0.0))
timestamp = time.perf_counter() - start_episode_t
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
log_say("Waiting for environment reset, press right arrow key when ready...")
if events["rerecord_episode"]:
log_say("Re-record episode")
@@ -190,7 +222,6 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+15 -17
View File
@@ -62,21 +62,20 @@ def main():
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# Build pipeline to convert follower joints to EE observation
# Build pipeline to convert follower joints to EE observation.
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
@@ -87,7 +86,7 @@ def main():
to_output=transition_to_observation,
)
# Build pipeline to convert leader joints to EE action
# Build pipeline to convert leader joints to EE action.
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
ForwardKinematicsJointsToEE(
@@ -98,9 +97,9 @@ def main():
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to follower joints
# Build pipeline to convert EE action to follower joints (with safety bounds).
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[
steps=[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
@@ -115,13 +114,12 @@ def main():
to_output=transition_to_robot_action,
)
# Create the dataset
# Create the dataset, deriving features from the pipelines so the on-disk schema
# matches exactly what the pipelines produce at runtime.
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=leader_joints_to_ee,
initial_features=create_initial_features(action=leader.action_features),
@@ -144,7 +142,7 @@ def main():
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
init_rerun(session_name="recording_so100_ee")
try:
if not leader.is_connected or not follower.is_connected:
@@ -160,14 +158,14 @@ def main():
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
@@ -179,13 +177,13 @@ def main():
robot=follower,
events=events,
fps=FPS,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
+134
View File
@@ -0,0 +1,134 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a trained EE-space policy on SO100 without recording (base rollout).
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
control tick). The custom observation/action processors convert between
joint space (robot hardware) and end-effector space (policy I/O) via
forward/inverse kinematics.
"""
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.configs import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
RobotProcessorPipeline,
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
from lerobot.rollout.inference import SyncInferenceConfig
from lerobot.rollout.strategies import BaseStrategy
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
FPS = 30
DURATION_SEC = 60
TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
def main():
init_logging()
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
# Kinematic solver: we need the motor-name list, so peek at the robot once.
# (The rollout engine owns the connected instance; we only use this for introspection.)
temp_robot = SO100Follower(robot_config)
motor_names = list(temp_robot.bus.motors.keys())
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=motor_names,
)
# Joint-space observation → EE-space observation (consumed by the policy).
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# EE-space action (produced by the policy) → joint-space action (sent to robot).
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=motor_names,
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Policy config (full model is loaded inside build_rollout_context).
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
cfg = RolloutConfig(
robot=robot_config,
policy=policy_config,
strategy=BaseStrategyConfig(),
inference=SyncInferenceConfig(),
fps=FPS,
duration=DURATION_SEC,
task=TASK_DESCRIPTION,
)
signal_handler = ProcessSignalHandler(use_threads=True)
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
# otherwise skip the joint↔EE conversion and the policy would receive the
# wrong observation/action space.
ctx = build_rollout_context(
cfg,
signal_handler.shutdown_event,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
strategy = BaseStrategy(cfg.strategy)
try:
strategy.setup(ctx)
strategy.run(ctx)
finally:
strategy.teardown(ctx)
if __name__ == "__main__":
main()
+1 -1
View File
@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rewards.classifier.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -1,7 +1,7 @@
import torch
from lerobot.datasets import LeRobotDataset
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
def main():
@@ -22,10 +22,10 @@ def main():
model_name="microsoft/resnet-18",
)
# Make policy, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta)
optimizer = config.get_optimizer_preset().build(policy.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
# Make reward model, preprocessor, and optimizer
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
classifier_id = "<user>/reward_classifier_hil_serl_example"
@@ -42,7 +42,7 @@ def main():
batch = preprocessor(batch)
# Forward pass
loss, output_dict = policy.forward(batch)
loss, output_dict = reward_model.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
@@ -58,8 +58,8 @@ def main():
print("Training finished!")
# You can now save the trained policy.
policy.push_to_hub(classifier_id)
# You can now save the trained reward model.
reward_model.push_to_hub(classifier_id)
if __name__ == "__main__":
+2 -11
View File
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.7.0,<5.0.0",
"datasets>=4.0.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
@@ -200,15 +200,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Annotation pipeline (lerobot-annotate). datatrove is mandatory; vllm is
# the preferred backend on Linux, with a transformers fallback elsewhere.
annotations = [
"lerobot[dataset]",
"lerobot[transformers-dep]",
"datatrove>=0.4.0,<2.0.0",
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
@@ -298,7 +289,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
-15
View File
@@ -1,15 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -1,36 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Steerable annotation pipeline producing ``language_persistent`` and
``language_events`` columns for LeRobot datasets.
The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:
- :mod:`.modules.plan_subtasks_memory` (Module 1) persistent styles
- :mod:`.modules.interjections_and_speech` (Module 2) event styles + speech
- :mod:`.modules.general_vqa` (Module 3) event-style VQA pairs
"""
from .config import AnnotationPipelineConfig
from .validator import StagingValidator, ValidationReport
from .writer import LanguageColumnsWriter
__all__ = [
"AnnotationPipelineConfig",
"LanguageColumnsWriter",
"StagingValidator",
"ValidationReport",
]
@@ -1,168 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class Module1Config:
"""Module 1 hyperparameters: plan + subtasks + memory.
Subtask decomposition sees the **whole episode** as one Qwen-VL video
block no keyframe stride or count: the model handles temporal pooling
itself and decides where to cut. ``max_video_frames`` only caps the
number of frames packed into the video block (a model-capacity bound,
not an annotation-logic knob).
"""
enabled: bool = True
max_video_frames: int = 32
min_subtask_seconds: float = 1.5
plan_max_steps: int = 8
use_video_url: bool = False
"""When True (and backend supports it, e.g. ``openai``), Module 1
sends a ``video_url`` content block pointing at the episode's mp4
file instead of pre-decoded frames. Lets the server sample frames at
its own ``fps`` no in-process conv3d cost. The video file is
extracted as a per-episode subclip to ``staging/.video_clips/`` so
the model sees only this episode's frames."""
use_video_url_fps: float = 1.0
"""Frame-rate hint to send to the server (mm_processor_kwargs.fps).
Only used when ``use_video_url=True``. ``1.0`` = sample 1 frame per
second, which is plenty for subtask-boundary detection on most
manipulation episodes."""
@dataclass
class Module2Config:
"""Module 2 hyperparameters: interjections + paired speech."""
enabled: bool = True
max_interjections_per_episode: int = 1
interjection_min_t: float = 2.0
@dataclass
class Module3Config:
"""Module 3 hyperparameters: general VQA."""
enabled: bool = True
vqa_emission_hz: float = 1.0
K: int = 3
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
@dataclass
class VlmConfig:
"""Shared Qwen-VL client configuration."""
backend: str = "vllm"
"""One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only).
The ``openai`` backend talks to any OpenAI-compatible server works
with ``vllm serve``, ``transformers serve``, ``ktransformers serve``,
or hosted endpoints. Set ``api_base`` and (optionally) ``api_key``."""
model_id: str = "Qwen/Qwen3.6-27B-FP8"
api_base: str = "http://localhost:8000/v1"
"""Base URL for the ``openai`` backend."""
api_key: str = "EMPTY"
"""API key for the ``openai`` backend; ``EMPTY`` works for local servers."""
auto_serve: bool = True
"""When True with ``backend=openai``, the CLI probes ``api_base``
first; if no server answers, it spawns one (default:
``transformers serve``), waits for it to be ready, runs the
pipeline, and tears it down on exit. Default ``True`` so a single
``lerobot-annotate`` call can drive the whole flow. Set to ``False``
if you want to fail fast when no server is reachable (e.g. you're
pointing at a remote endpoint that should already be up)."""
serve_port: int = 8000
"""Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
serve_command: str | None = None
"""Override the auto-serve command (full shell command). When ``None``,
we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``."""
serve_ready_timeout_s: float = 600.0
"""Max seconds to wait for the server to start serving requests."""
max_new_tokens: int = 512
temperature: float = 0.2
json_mode: bool = True
batch_size: int = 4
tensor_parallel_size: int = 1
gpu_memory_utilization: float = 0.9
"""Fraction of GPU memory vllm allocates for weights + KV cache.
Lower (e.g. 0.7) when the vision encoder needs cuDNN workspace, or to
avoid CUDNN_STATUS_NOT_INITIALIZED on tight VRAM (30B BF16 on 80 GB)."""
max_model_len: int | None = None
"""Cap context length. ``None`` keeps the model's default; on H100 80 GB
a 30B BF16 model often needs ``max_model_len=8192`` or smaller to leave
room for KV cache."""
trust_remote_code: bool = False
"""Pass ``trust_remote_code`` to HF auto-classes. Default ``False`` —
only enable for models that actually ship custom code in their repo
(rare for first-class VL releases). On Qwen3-VL it triggers an
std::bad_alloc post-load even though the official transformers class
is sufficient, so leaving this off is safest."""
camera_key: str | None = None
"""Override the camera stream used for keyframe attachment. ``None`` picks
the first ``observation.images.*`` key the dataset declares."""
@dataclass
class ExecutorConfig:
"""Executor selection and SLURM hyperparameters."""
auto_threshold: int = 32
force_local: bool = False
slurm_partition: str | None = None
slurm_gpus: int = 1
slurm_time: str = "06:00:00"
workers: int = 1
@dataclass
class AnnotationPipelineConfig:
"""Top-level config for ``lerobot-annotate``.
Mirrors the structure of :class:`lerobot.configs.train.TrainPipelineConfig`:
a draccus-parsed dataclass that contains nested per-module sub-configs and
leaves the dataset, executor, and VLM choices independently knobbable.
Output is always in-place: the writer rewrites ``data/chunk-*/file-*.parquet``
in place. Multiple revisions of the same dataset live in separate copies.
"""
repo_id: str | None = None
root: Path | None = None
staging_dir: Path | None = None
"""If unset, defaults to ``<root>/.annotate_staging/``."""
seed: int = 1729
module_1: Module1Config = field(default_factory=Module1Config)
module_2: Module2Config = field(default_factory=Module2Config)
module_3: Module3Config = field(default_factory=Module3Config)
vlm: VlmConfig = field(default_factory=VlmConfig)
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None
def resolved_staging_dir(self, root: Path) -> Path:
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
@@ -1,163 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Executor selection: local vs SLURM via datatrove.
The executor plans **four phases** with the dependency order from the plan:
phase 1: Module 1 (plan + subtasks + memory)
phase 2: Module 2 (interjections + speech)
phase 3: Module 1 plan-update pass re-runs plan emission at every
interjection timestamp produced by phase 2
phase 4: Module 3 (VQA)
phase 5: validator
phase 6: writer
Phase 3 is why ``executor.py`` documents the dependency: Module 1 must be
re-entered after Module 2 to refresh ``plan`` rows at interjection times.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .config import AnnotationPipelineConfig, ExecutorConfig
from .reader import EpisodeRecord, iter_episodes
from .staging import EpisodeStaging
from .validator import StagingValidator
from .writer import LanguageColumnsWriter
logger = logging.getLogger(__name__)
@dataclass
class PhaseResult:
"""Summary of one pipeline phase across all episodes."""
name: str
episodes_processed: int
episodes_skipped: int
@dataclass
class PipelineRunSummary:
"""Aggregated result returned by :meth:`Executor.run`."""
phases: list[PhaseResult]
written_paths: list[Path]
validation_report: Any # ValidationReport, kept Any to avoid import cycle
def select_executor_class(num_episodes: int, config: ExecutorConfig) -> str:
"""Return ``"local"`` or ``"slurm"`` based on the threshold.
The plan's "executor selection threshold" lives in
:class:`ExecutorConfig.auto_threshold`. ``force_local`` always wins.
"""
if config.force_local:
return "local"
return "local" if num_episodes <= config.auto_threshold else "slurm"
@dataclass
class Executor:
"""Run all four phases over a dataset root.
The executor is intentionally framework-agnostic: by default it runs the
phases inline (suitable for tests, small datasets, and the CLI's
``--force-local`` mode). It will optionally hand off to datatrove's
:class:`LocalPipelineExecutor` or :class:`SlurmPipelineExecutor` when those
are installed and the dataset is large enough to benefit from them.
Tests construct the executor directly with stub modules.
"""
config: AnnotationPipelineConfig
module_1: Any # PlanSubtasksMemoryModule
module_2: Any # InterjectionsAndSpeechModule
module_3: Any # GeneralVqaModule
writer: LanguageColumnsWriter
validator: StagingValidator
def run(self, root: Path) -> PipelineRunSummary:
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
n = len(records)
if n == 0:
raise ValueError(f"No episodes found under {root}/data/")
executor_kind = select_executor_class(n, self.config.executor)
logger.info("annotate: %d episodes; executor=%s", n, executor_kind)
staging_dir = self.config.resolved_staging_dir(root)
staging_dir.mkdir(parents=True, exist_ok=True)
phases: list[PhaseResult] = []
# Phase 1: Module 1 (plan + subtasks + memory)
phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
# Phase 2: Module 2 (interjections + speech)
phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
# Phase 3: Module 1 plan-update pass at interjection timestamps.
phases.append(self._run_plan_update_phase(records, staging_dir))
# Phase 4: Module 3 (VQA)
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
report = self.validator.validate(records, staging_dir)
if not report.ok and not self.config.skip_validation:
raise RuntimeError(f"Staging validation failed: {report.summary()}")
written = self.writer.write_all(records, staging_dir, root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
def _run_module_phase(
self,
name: str,
records: list[EpisodeRecord],
staging_dir: Path,
module: Any,
) -> PhaseResult:
if not module.enabled:
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging)
processed += 1
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase(self, records: list[EpisodeRecord], staging_dir: Path) -> PhaseResult:
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
therefore calls back into Module 1 with the interjection timestamps so
Module 1's existing prompt path is reused.
"""
if not self.module_1.enabled or not self.module_2.enabled:
return PhaseResult(
name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records)
)
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)
interjection_times = [
row["timestamp"] for row in staging.read("module_2") if row.get("style") == "interjection"
]
if interjection_times:
self.module_1.run_plan_updates(record, staging, interjection_times)
processed += 1
return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0)
@@ -1,264 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keyframe extraction for the annotation pipeline.
Modules attach decoded camera frames to their VLM prompts so the model can
ground subtask decomposition, interjection scenarios, and VQA in actual
visual content. The pipeline shares one provider across modules and one
episode at a time, with a small per-episode cache so multiple modules
querying the same timestamp pay decode cost once.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
from .reader import EpisodeRecord
class FrameProvider(Protocol):
"""Decodes camera frames at episode-relative timestamps."""
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
"""Return one PIL.Image per timestamp; empty list if no camera available."""
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
"""Return up to ``max_frames`` PIL images covering the whole episode.
Sampling is uniform across the episode duration. The returned list is
intended to be passed as one ``{"type":"video", "video":<list>}``
block to a Qwen-VL-compatible model that pools temporally itself.
Empty list if no camera available.
"""
@dataclass
class _NullProvider:
"""No-op provider used when the dataset has no video keys or in tests."""
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
return []
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
return []
def null_provider() -> FrameProvider:
return _NullProvider()
@dataclass
class VideoFrameProvider:
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
The first camera key is used unconditionally Module 1/2/3 prompts care
about *what is happening*, not which camera angle the model sees, so a
single canonical viewpoint is enough. Override ``camera_key`` if you
want a specific stream.
Caches up to ``cache_size`` decoded frames per process to keep
co-timestamped Module 2 + Module 1 plan-update calls cheap.
"""
root: Path
camera_key: str | None = None
tolerance_s: float = 1e-2
cache_size: int = 256
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
def __post_init__(self) -> None:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
if self.camera_key is None:
keys = self._meta.video_keys
self.camera_key = keys[0] if keys else None
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
if not timestamps or self.camera_key is None:
return []
out: list[Any] = []
misses: list[float] = []
miss_indices: list[int] = []
for i, ts in enumerate(timestamps):
key = (record.episode_index, round(float(ts), 6))
cached = self._cache.get(key)
if cached is not None:
out.append(cached)
else:
out.append(None)
misses.append(float(ts))
miss_indices.append(i)
if misses:
decoded = self._decode(record.episode_index, misses)
# decoder may return fewer frames than requested when some
# timestamps fall outside the video; pair what we have and
# leave the rest as None to be filtered below.
for i, img in zip(miss_indices, decoded):
out[i] = img
key = (record.episode_index, round(float(timestamps[i]), 6))
if len(self._cache) >= self.cache_size:
self._cache.pop(next(iter(self._cache)))
self._cache[key] = img
# filter out any None left over from decode failures
return [img for img in out if img is not None]
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
import os as _os # noqa: PLC0415
from PIL import Image # noqa: PLC0415
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
# ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
# some environments; default to ``pyav`` (always available via
# the ``av`` package) and let users override with
# LEROBOT_VIDEO_BACKEND=torchcodec when their stack supports it.
backend = _os.environ.get("LEROBOT_VIDEO_BACKEND", "pyav")
try:
frames = decode_video_frames(
video_path,
shifted,
self.tolerance_s,
backend=backend,
return_uint8=True,
)
except Exception:
return []
# frames: [N, C, H, W] uint8, RGB
out: list[Any] = []
arr = frames.cpu().numpy() if hasattr(frames, "cpu") else frames
for i in range(arr.shape[0]):
chw = arr[i]
hwc = chw.transpose(1, 2, 0)
out.append(Image.fromarray(hwc, mode="RGB"))
return out
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
"""Return up to ``max_frames`` images uniformly sampled across the episode.
The whole episode duration is covered; the model picks subtask
boundaries from the temporal pooling it does internally.
"""
if max_frames <= 0 or self.camera_key is None or not record.frame_timestamps:
return []
n_frames = min(max_frames, len(record.frame_timestamps))
if n_frames == len(record.frame_timestamps):
timestamps = list(record.frame_timestamps)
else:
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
if t_last <= t0:
timestamps = [float(t0)] * n_frames
else:
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
timestamps = [float(t0 + i * step) for i in range(n_frames)]
return self.frames_at(record, timestamps)
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
try:
provider = VideoFrameProvider(root=root, camera_key=camera_key)
except Exception:
return null_provider()
if provider.camera_key is None:
return null_provider()
return provider
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
"""Convert PIL images to Qwen-VL-compatible content blocks."""
return [{"type": "image", "image": img} for img in images]
def to_video_block(images: list[Any]) -> list[dict[str, Any]]:
"""Wrap a list of PIL images as one Qwen-VL video block.
Returns ``[]`` when the list is empty, so the caller can splat the result
into a content array without a separate emptiness check.
"""
if not images:
return []
return [{"type": "video", "video": list(images)}]
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
"""Wrap a video file URL as one ``video_url`` block.
Used by the ``openai`` backend (transformers serve / vllm serve /
ktransformers serve), where the server handles frame sampling.
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
"""
if not url:
return []
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
def episode_clip_path(
record: EpisodeRecord,
provider: "VideoFrameProvider",
cache_dir: Path,
) -> Path | None:
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
Returns ``None`` if the dataset has no video tracks. Skips re-extract
when the cached clip already exists. Uses ``ffmpeg`` via subprocess
with stream-copy where possible (no re-encode) for speed.
"""
import subprocess # noqa: PLC0415
if provider.camera_key is None:
return None
cache_dir.mkdir(parents=True, exist_ok=True)
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
if out_path.exists() and out_path.stat().st_size > 0:
return out_path
ep = provider._meta.episodes[record.episode_index]
from_timestamp = float(ep[f"videos/{provider.camera_key}/from_timestamp"])
to_timestamp = float(ep[f"videos/{provider.camera_key}/to_timestamp"])
src = provider.root / provider._meta.get_video_file_path(
record.episode_index, provider.camera_key
)
cmd = [
"ffmpeg",
"-y",
"-loglevel",
"error",
"-ss",
f"{from_timestamp:.3f}",
"-to",
f"{to_timestamp:.3f}",
"-i",
str(src),
"-c",
"copy",
str(out_path),
]
try:
subprocess.run(cmd, check=True, timeout=120)
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
return None
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
@@ -1,25 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .general_vqa import GeneralVqaModule
from .interjections_and_speech import InterjectionsAndSpeechModule
from .plan_subtasks_memory import PlanSubtasksMemoryModule
__all__ = [
"GeneralVqaModule",
"InterjectionsAndSpeechModule",
"PlanSubtasksMemoryModule",
]
@@ -1,152 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 3: general VQA at a timed cadence.
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
emission so each frame gets at most one ``(vqa, user)`` and one
``(vqa, assistant)`` pair keeps the resolver contract scalar.
Question types covered (per the plan's Module 3 table): bbox, keypoint,
count, attribute, spatial. The assistant's ``content`` is a JSON string
whose schema depends on the question type. Malformed JSON triggers one
retry inside :meth:`VlmClient.generate_json`.
"""
from __future__ import annotations
import json
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import Module3Config
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..validator import classify_vqa_answer
from ..vlm_client import VlmClient
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
"""Return the relative frame indices to anchor VQA emissions to.
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
consecutive frames starting at the tick. Ticks fall on the nearest
available source frame timestamp.
"""
if hz <= 0 or k <= 0 or not frame_timestamps:
return []
t0 = frame_timestamps[0]
t_last = frame_timestamps[-1]
period = 1.0 / hz
indices: list[int] = []
t = t0
while t <= t_last + 1e-9:
# find the index of the nearest frame to t
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
for offset in range(k):
j = nearest_i + offset
if j >= len(frame_timestamps):
break
if not indices or indices[-1] != j:
indices.append(j)
t += period
# dedupe while preserving order
seen: set[int] = set()
deduped: list[int] = []
for i in indices:
if i in seen:
continue
seen.add(i)
deduped.append(i)
return deduped
@dataclass
class GeneralVqaModule:
"""Emit grounded VQA pairs at a timed cadence."""
vlm: VlmClient
config: Module3Config
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
if not record.frame_timestamps:
staging.write("module_3", [])
return
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
)
rows: list[dict[str, Any]] = []
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
qa = self._generate_one(record, qtype, ts)
if qa is None:
continue
question, answer = qa
rows.append(
{
"role": "user",
"content": question,
"style": "vqa",
"timestamp": ts,
"tool_calls": None,
}
)
rows.append(
{
"role": "assistant",
"content": json.dumps(answer, sort_keys=True),
"style": "vqa",
"timestamp": ts,
"tool_calls": None,
}
)
staging.write("module_3", rows)
def _generate_one(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None:
prompt = load_prompt("module_3_vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
images = self.frame_provider.frames_at(record, [frame_timestamp])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
return None
question = result.get("question")
answer = result.get("answer")
if not isinstance(question, str) or not question.strip():
return None
if not isinstance(answer, dict):
return None
# The validator will enforce shape; here we just sanity-check that the
# answer matches *some* known shape so we can drop garbage early.
if classify_vqa_answer(answer) is None:
return None
return question.strip(), answer
@@ -1,133 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).
Two sub-passes:
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
canonical task). No interjection row the canonical task is already the
user utterance from ``meta/tasks.parquet``.
2. For mid-episode interruptions, emit a co-timestamped pair:
{role:user, style:interjection, content:<text>}
speech atom (role:assistant, style:None, tool_calls=[say(...)])
Both rows go in ``language_events`` at the same timestamp.
Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
timestamps to refresh the ``plan`` row at the same instant.
"""
from __future__ import annotations
import random
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from ..config import Module2Config
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
from ..writer import speech_atom
def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
if not frame_timestamps:
return float(t)
return float(min(frame_timestamps, key=lambda f: abs(f - t)))
@dataclass
class InterjectionsAndSpeechModule:
"""Generate task-start speech and mid-episode interjection/speech pairs."""
vlm: VlmClient
config: Module2Config
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
if record.frame_timestamps:
t0 = float(record.frame_timestamps[0])
initial = self._initial_speech(record)
if initial:
rows.append(speech_atom(t0, initial))
rows.extend(self._mid_episode_interjections(record))
staging.write("module_2", rows)
def _initial_speech(self, record: EpisodeRecord) -> str | None:
prompt = load_prompt("module_2_initial_speech").format(
episode_task=record.episode_task,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("text"), str):
text = result["text"].strip()
if text:
return text
return None
def _mid_episode_interjections(self, record: EpisodeRecord) -> list[dict[str, Any]]:
if self.config.max_interjections_per_episode <= 0:
return []
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
candidate_ts = [t for t in record.frame_timestamps if t >= self.config.interjection_min_t]
if not candidate_ts:
return []
n = min(self.config.max_interjections_per_episode, len(candidate_ts) // 4)
if n <= 0:
return []
chosen = sorted(rng.sample(candidate_ts, n))
out: list[dict[str, Any]] = []
for t in chosen:
t_snap = _snap_to_frame(t, record.frame_timestamps)
current_subtask = record.episode_task
prompt = load_prompt("module_2_interjection").format(
episode_task=record.episode_task,
current_subtask=current_subtask,
timestamp=t_snap,
)
images = self.frame_provider.frames_at(record, [t_snap])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
continue
interjection_text = result.get("interjection")
speech_text = result.get("speech")
if not isinstance(interjection_text, str) or not interjection_text.strip():
continue
if not isinstance(speech_text, str) or not speech_text.strip():
continue
out.append(
{
"role": "user",
"content": interjection_text.strip(),
"style": "interjection",
"timestamp": t_snap,
"tool_calls": None,
}
)
out.append(speech_atom(t_snap, speech_text.strip()))
return out
@@ -1,249 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles)."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from pathlib import Path
from ..config import Module1Config
from ..frames import (
FrameProvider,
VideoFrameProvider,
episode_clip_path,
null_provider,
to_video_block,
to_video_url_block,
)
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
from ..vlm_client import VlmClient
def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
"""Snap an arbitrary float to the nearest exact source frame timestamp."""
if not frame_timestamps:
return float(t)
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
return float(nearest)
@dataclass
class PlanSubtasksMemoryModule:
"""Generate subtask spans, plan, and memory rows.
All output is persistent (lives in ``language_persistent``):
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
(snapped to an exact frame).
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
timestamp via :meth:`run_plan_updates` (called by the executor after
Module 2 completes).
- ``memory`` rows: emitted at each subtask boundary (= subtask start
timestamp from the second subtask onward).
"""
vlm: VlmClient
config: Module1Config
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
return self.config.enabled
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
rows: list[dict[str, Any]] = []
subtask_spans = self._generate_subtasks(record)
# subtask rows
for span in subtask_spans:
rows.append(
{
"role": "assistant",
"content": span["text"],
"style": "subtask",
"timestamp": _snap_to_frame(span["start"], record.frame_timestamps),
"tool_calls": None,
}
)
# plan row at t=0
plan_text = self._generate_plan(record, subtask_spans)
if plan_text is not None:
t0 = record.frame_timestamps[0] if record.frame_timestamps else 0.0
rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": float(t0),
"tool_calls": None,
}
)
# memory rows at every subtask boundary except the very first start
prior_memory = ""
for i, span in enumerate(subtask_spans[1:], start=1):
completed = subtask_spans[i - 1]["text"]
remaining = [s["text"] for s in subtask_spans[i:]]
mem_text = self._generate_memory(record, prior_memory, completed, remaining)
if mem_text:
ts = _snap_to_frame(span["start"], record.frame_timestamps)
rows.append(
{
"role": "assistant",
"content": mem_text,
"style": "memory",
"timestamp": ts,
"tool_calls": None,
}
)
prior_memory = mem_text
staging.write("module_1", rows)
def run_plan_updates(
self,
record: EpisodeRecord,
staging: EpisodeStaging,
interjection_times: Sequence[float],
) -> None:
"""Append additional ``plan`` rows at every interjection timestamp."""
existing = staging.read("module_1")
spans = self._reconstruct_subtasks_from_rows(existing)
new_rows = list(existing)
for raw_t in interjection_times:
t = _snap_to_frame(raw_t, record.frame_timestamps)
plan_text = self._generate_plan(record, spans, refresh_t=t)
if plan_text is not None:
new_rows.append(
{
"role": "assistant",
"content": plan_text,
"style": "plan",
"timestamp": t,
"tool_calls": None,
}
)
staging.write("module_1", new_rows)
@staticmethod
def _reconstruct_subtasks_from_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
out = []
last_t: float | None = None
for row in sorted(
(r for r in rows if r.get("style") == "subtask"),
key=lambda r: float(r["timestamp"]),
):
t = float(row["timestamp"])
if last_t is not None:
out[-1]["end"] = t
out.append({"text": row.get("content") or "", "start": t, "end": t})
last_t = t
return out
def _generate_subtasks(self, record: EpisodeRecord) -> list[dict[str, Any]]:
if record.row_count == 0 or not record.frame_timestamps:
return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
prompt = load_prompt("module_1_subtasks").format(
episode_task=record.episode_task,
min_subtask_seconds=self.config.min_subtask_seconds,
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
)
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
clip = episode_clip_path(record, self.frame_provider, cache_dir)
video_block = (
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
if clip is not None
else []
)
else:
video_frames = self.frame_provider.video_for_episode(
record, self.config.max_video_frames
)
video_block = to_video_block(video_frames)
content = [*video_block, {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
spans = result.get("subtasks") if isinstance(result, dict) else None
if not spans:
return []
# clamp to [t0, t_last] and sort
t0 = record.frame_timestamps[0]
t_last = record.frame_timestamps[-1]
cleaned: list[dict[str, Any]] = []
for span in spans:
try:
start = float(span["start"])
end = float(span["end"])
text = str(span["text"]).strip()
except (KeyError, ValueError, TypeError):
continue
start = max(t0, min(start, t_last))
end = max(t0, min(end, t_last))
if end < start:
start, end = end, start
if not text:
continue
cleaned.append({"text": text, "start": start, "end": end})
cleaned.sort(key=lambda s: s["start"])
return cleaned
def _generate_plan(
self,
record: EpisodeRecord,
subtask_spans: Sequence[dict[str, Any]],
*,
refresh_t: float | None = None,
) -> str | None:
if not subtask_spans:
return None
subtasks_text = "\n".join(f"- {s['text']}" for s in subtask_spans)
prompt = load_prompt("module_1_plan").format(
episode_task=record.episode_task,
subtasks_text=subtasks_text,
plan_max_steps=self.config.plan_max_steps,
)
if refresh_t is not None:
prompt += f"\n\n(This is a plan refresh after a user interjection at t={refresh_t:.2f}s.)\n"
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("plan"), str):
return result["plan"].strip()
return None
def _generate_memory(
self,
record: EpisodeRecord,
prior_memory: str,
completed: str,
remaining: Sequence[str],
) -> str:
prompt = load_prompt("module_1_memory").format(
episode_task=record.episode_task,
prior_memory=prior_memory or "(none)",
completed_subtask=completed,
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
result = self.vlm.generate_json([messages])[0]
if isinstance(result, dict) and isinstance(result.get("memory"), str):
return result["memory"].strip()
return ""
@@ -1,33 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.
One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""
from __future__ import annotations
from pathlib import Path
_DIR = Path(__file__).parent
def load(name: str) -> str:
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
path = _DIR / f"{name}.txt"
return path.read_text(encoding="utf-8")
@@ -1,25 +0,0 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.
Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."
Concrete example from MEM:
Before: "I put a light green bowl, a dark blue bowl and a bright yellow
bowl into the top right cabinet"
After: "I placed three bowls in the top right cabinet"
Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
Update the memory. Drop irrelevant detail. Compress completed steps.
Keep WHAT happened, drop HOW. Shorter is better.
Output strictly valid JSON:
{{ "memory": "<one or two short sentences>" }}
@@ -1,18 +0,0 @@
You are the high-level planner for a robot demonstrating: "{episode_task}".
Given the subtask decomposition below, write a concise hierarchical PLAN
the robot should follow. Format the plan as a numbered list, one line per
high-level step. The plan describes the full task; subtasks are the atomic
skills used to execute it.
Subtasks for context:
{subtasks_text}
Authoring rules:
- 3 to {plan_max_steps} steps.
- Each step describes one logical chunk of the task, not one motion.
- Steps must be in execution order.
- Plain prose, no JSON, no markdown headers.
Output strictly valid JSON:
{{ "plan": "1. ...\n2. ...\n3. ..." }}
@@ -1,33 +0,0 @@
You are labeling a teleoperated robot demonstration.
The user originally asked: "{episode_task}"
You are shown the entire demonstration as a single video. Watch the
whole clip, then segment it into a list of consecutive atomic subtasks
the robot performs.
Authoring rules — based on Hi Robot (Shi 2025) atom granularity and
Pi0.7 (Physical Intelligence 2025) "how, not what" detail:
- Each subtask is one atomic skill the low-level policy can execute,
e.g. "pick up one piece of lettuce", "place the bowl into the box",
"move the right arm to the left".
- Capture HOW the subtask is performed, not only WHAT — e.g. prefer
"grasp the handle of the sponge with the left hand" to "pick up the
sponge".
- Subtasks are non-overlapping and cover the full episode in order.
Choose the cut points yourself based on what you see in the video
(gripper open/close events, contact, regrasps, transitions).
- Each subtask spans at least {min_subtask_seconds} seconds.
- Do not exceed {max_steps} subtasks total.
- Every subtask's [start_time, end_time] must lie within
[0.0, {episode_duration}] seconds.
Output strictly valid JSON of shape:
{{
"subtasks": [
{{"text": "<how-not-what>", "start": <float>, "end": <float>}},
...
]
}}
@@ -1,10 +0,0 @@
The user just asked the robot: "{episode_task}".
Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: confident, friendly, single short sentence.
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".
Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}
@@ -1,27 +0,0 @@
You are simulating a user mid-episode interruption for a robot doing:
"{episode_task}".
Synthesize ONE realistic interruption the user might say at this moment in
the demonstration, plus the robot's verbal acknowledgement.
Context (Hi Robot, Shi 2025) — interjections fall into one of these
scenario types:
- negative task: "actually skip X"
- situated correction: "that's not trash"
- specific constraint: "use less salt"
- preference: "could you also do Y"
Interruption rules:
- Must be plausible given the current subtask context.
- Must change the plan in a non-trivial way (a new constraint, skipped
step, or correction).
- One sentence each.
Current subtask context: {current_subtask}
Time into episode: {timestamp:.2f}s
Output strictly valid JSON:
{{
"interjection": "<single sentence the user says>",
"speech": "<single sentence the robot speaks back>"
}}
@@ -1,32 +0,0 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}".
Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}}
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}}
count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}}
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
"value": "<observed value>"}}
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
"above|below|near>", "object": "<obj>"}}
Generate a question of type "{question_type}". Output strictly valid JSON:
{{
"question": "<short, frame-grounded question>",
"answer": <object whose shape matches the schema above>
}}
@@ -1,219 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:
- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
This shape lets each module operate per-episode without loading all parquet
rows into memory at once. It deliberately does not depend on datatrove
datatrove integration wraps this generator inside a ``PipelineStep`` in
:mod:`.executor`.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow.parquet as pq
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
@dataclass
class EpisodeRecord:
"""Per-episode record yielded by the reader."""
episode_index: int
episode_task: str
frame_timestamps: tuple[float, ...]
frame_indices: tuple[int, ...]
data_path: Path
row_offset: int # row offset within the parquet file where this episode starts
row_count: int # number of rows for this episode
def frames_df(self): # type: ignore[no-untyped-def]
"""Lazy-load the pandas slice for this episode."""
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
table = pq.read_table(self.data_path)
df: pd.DataFrame = table.to_pandas()
slice_ = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(drop=True)
return slice_
def _load_tasks_lookup(root: Path) -> dict[int, str]:
tasks_path = root / DEFAULT_TASKS_PATH
if not tasks_path.exists():
return {}
table = pq.read_table(tasks_path)
cols = {name: table.column(name).to_pylist() for name in table.column_names}
if "task_index" in cols and "task" in cols:
return dict(zip(cols["task_index"], cols["task"], strict=True))
raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
Episodes are yielded in ascending ``episode_index`` order. The reader does
not assume a specific chunk/file layout: it scans every ``*.parquet``
under ``data/`` and groups by ``episode_index``.
"""
tasks = _load_tasks_lookup(root)
data_dir = root / "data"
parquet_files = sorted(data_dir.rglob("*.parquet"))
only_set = set(only_episodes) if only_episodes is not None else None
for path in parquet_files:
yield from _iter_one_path(path, tasks, only_set)
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
table = pq.read_table(path)
names = table.column_names
if "episode_index" not in names:
return
episode_col = table.column("episode_index").to_pylist()
timestamp_col = (
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
)
frame_col = (
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
)
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
def _build(
ep: int,
start: int,
end: int,
task_idx: int | None,
ts_buf: list[float],
fi_buf: list[int],
) -> EpisodeRecord | None:
if only_set is not None and ep not in only_set:
return None
task = tasks.get(task_idx, "") if task_idx is not None else ""
return EpisodeRecord(
episode_index=ep,
episode_task=task,
frame_timestamps=tuple(ts_buf),
frame_indices=tuple(fi_buf),
data_path=path,
row_offset=start,
row_count=end - start,
)
cur_ep: int | None = None
start_offset = 0
ts_buf: list[float] = []
fi_buf: list[int] = []
cur_task_idx: int | None = None
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
continue
if ep != cur_ep:
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
cur_ep = ep
start_offset = i
ts_buf = [timestamp_col[i]]
fi_buf = [frame_col[i]]
cur_task_idx = task_col[i] if task_col is not None else None
else:
ts_buf.append(timestamp_col[i])
fi_buf.append(frame_col[i])
if cur_ep is not None:
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
if rec is not None:
yield rec
def gather_data_paths(root: Path) -> list[Path]:
"""Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
return sorted((root / "data").rglob("*.parquet"))
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
"""Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
table = pq.read_table(path, columns=["episode_index"])
episode_col = table.column("episode_index").to_pylist()
out: dict[int, tuple[int, int]] = {}
cur_ep: int | None = None
start = 0
for i, ep in enumerate(episode_col):
if cur_ep is None:
cur_ep = ep
start = i
continue
if ep != cur_ep:
out[cur_ep] = (start, i - start)
cur_ep = ep
start = i
if cur_ep is not None:
out[cur_ep] = (start, len(episode_col) - start)
return out
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
"""Return ``k`` evenly spaced row indices into the episode (relative)."""
n = record.row_count
if k <= 0 or n == 0:
return []
if k >= n:
return list(range(n))
step = (n - 1) / (k - 1) if k > 1 else 0.0
return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
"""Find the parquet file containing ``episode_index`` and its slice bounds."""
for path in gather_data_paths(root):
offsets = episode_offsets_per_path(path)
if episode_index in offsets:
start, count = offsets[episode_index]
return path, start, count
return None
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
"""Return the parquet path and per-frame timestamps for ``episode_index``."""
found = lookup_data_path(root, episode_index)
if found is None:
raise ValueError(f"Episode {episode_index} not found under {root}/data/")
path, start, count = found
table = pq.read_table(path, columns=["timestamp"])
timestamps = table.column("timestamp").to_pylist()[start : start + count]
return path, [float(t) for t in timestamps]
@@ -1,98 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.
Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
staging tree and partitions rows into the two language columns.
JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""
from __future__ import annotations
import json
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any
ModuleName = str
_MODULES: tuple[ModuleName, ...] = (
"module_1",
"module_2",
"module_3",
)
@dataclass
class EpisodeStaging:
"""Filesystem layout for a single episode's staged module outputs."""
root: Path
episode_index: int
@property
def episode_dir(self) -> Path:
return self.root / f"episode_{self.episode_index:06d}"
def path_for(self, module: ModuleName) -> Path:
if module not in _MODULES:
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
return self.episode_dir / f"{module}.jsonl"
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
path = self.path_for(module)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
f.write("\n")
return path
def read(self, module: ModuleName) -> list[dict[str, Any]]:
path = self.path_for(module)
if not path.exists():
return []
out: list[dict[str, Any]] = []
with path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
out.append(json.loads(line))
return out
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
return {m: self.read(m) for m in _MODULES}
def has(self, module: ModuleName) -> bool:
return self.path_for(module).exists()
def iter_staged_episodes(root: Path) -> Iterator[int]:
"""Yield episode indices for which any staging artifact exists."""
if not root.exists():
return
for child in sorted(root.iterdir()):
if child.is_dir() and child.name.startswith("episode_"):
try:
yield int(child.name.removeprefix("episode_"))
except ValueError:
continue
@@ -1,271 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-write validation against staged outputs.
Runs after Modules 13 have all written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.
Checks (per the plan's "Intermediate staging and validation" section):
- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import (
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
column_for_style,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
@dataclass
class ValidationReport:
"""Outcome of one validation pass across all episodes."""
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
episodes_checked: int = 0
@property
def ok(self) -> bool:
return not self.errors
def add_error(self, message: str) -> None:
self.errors.append(message)
def add_warning(self, message: str) -> None:
self.warnings.append(message)
def summary(self) -> str:
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
def classify_vqa_answer(payload: Any) -> str | None:
"""Best-effort classification of a VQA answer payload to a question type."""
if not isinstance(payload, dict):
return None
keys = set(payload.keys())
for kind, required in VQA_ANSWER_SHAPES.items():
if required.issubset(keys):
return kind
return None
@dataclass
class StagingValidator:
"""Walks the staging tree and produces a :class:`ValidationReport`."""
timestamp_atol: float = 0.0 # exact-match by default
def validate(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
) -> ValidationReport:
report = ValidationReport()
for record in records:
self._validate_episode(record, staging_dir, report)
report.episodes_checked += 1
return report
def _validate_episode(
self,
record: EpisodeRecord,
staging_dir: Path,
report: ValidationReport,
) -> None:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged = staging.read_all()
all_rows: list[dict[str, Any]] = []
for module_name, rows in staged.items():
for row in rows:
row = {**row, "_module": module_name}
all_rows.append(row)
frame_ts = set(record.frame_timestamps)
events: list[dict[str, Any]] = []
persistent: list[dict[str, Any]] = []
for row in all_rows:
self._check_column_routing(row, report, record.episode_index)
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
for row in events:
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
self._check_speech_interjection_pairs(events, report, record.episode_index)
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
self._check_vqa_json(events, report, record.episode_index)
def _check_column_routing(
self,
row: dict[str, Any],
report: ValidationReport,
episode_index: int,
) -> None:
style = row.get("style")
module = row.get("_module")
try:
target_col = column_for_style(style)
except ValueError:
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
return
if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
report.add_error(
f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
)
if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
report.add_error(
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
)
def _check_event_timestamp_alignment(
self,
row: dict[str, Any],
frame_ts: set[float],
report: ValidationReport,
episode_index: int,
) -> None:
ts = row.get("timestamp")
if ts is None:
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
return
if self.timestamp_atol == 0.0:
if float(ts) not in frame_ts:
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
)
else:
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
report.add_error(
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
)
def _check_speech_interjection_pairs(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
speech_ts: dict[float, int] = {}
interjection_ts: dict[float, int] = {}
for row in events:
ts = row.get("timestamp")
if ts is None:
continue
ts_f = float(ts)
if row.get("style") is None and row.get("role") == "assistant":
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
if row.get("style") == "interjection":
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
for ts in interjection_ts:
if ts not in speech_ts:
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
def _check_plan_memory_consistency(
self,
persistent: Sequence[dict[str, Any]],
events: Sequence[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
interjection_ts = sorted(
{
float(r["timestamp"])
for r in events
if r.get("style") == "interjection" and r.get("timestamp") is not None
}
)
if persistent and not plan_ts:
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
# every interjection should have a same-timestamp plan refresh
for ts in interjection_ts:
if ts not in set(plan_ts):
report.add_error(
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
)
# memory should be emitted at subtask boundaries (subset relation)
if memory_ts and subtask_ts:
mem_set = set(memory_ts)
sub_set = set(subtask_ts)
stray = sorted(mem_set - sub_set)
if stray:
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
def _check_vqa_json(
self,
events: Iterable[dict[str, Any]],
report: ValidationReport,
episode_index: int,
) -> None:
for row in events:
if row.get("style") != "vqa" or row.get("role") != "assistant":
continue
content = row.get("content")
if content is None:
report.add_error(
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
)
continue
try:
payload = json.loads(content)
except (TypeError, ValueError) as exc:
report.add_error(
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
)
continue
shape = classify_vqa_answer(payload)
if shape is None:
report.add_error(
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
)
@@ -1,453 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Qwen-VL client.
The pipeline uses a single shared VLM across modules. vLLM is preferred when
available (high throughput, JSON-guided decoding); transformers is the
fallback. A ``stub`` backend is used for unit tests so fixtures never call
into a real model.
The client speaks one method, :meth:`VlmClient.generate_json`, which:
- accepts a list of OpenAI/HF-style multimodal messages,
- requests JSON output (``json_mode=True`` enables guided decoding when the
backend supports it),
- batches requests transparently,
- and reprompts once on a JSON parse failure with an inline correction
message before raising.
"""
from __future__ import annotations
import json
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Protocol
from .config import VlmConfig
class VlmClient(Protocol):
"""Protocol every backend must implement."""
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
"""Generate one JSON-decoded response per messages list."""
@dataclass
class StubVlmClient:
"""Deterministic stub used in unit tests.
A test passes a callable that maps the *last user message text* (or, if
that is empty, the full message list) to a JSON-serializable response.
"""
responder: Callable[[Sequence[dict[str, Any]]], Any]
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
return [self.responder(list(messages)) for messages in messages_batch]
def _strip_to_json(text: str) -> Any:
text = text.strip()
if text.startswith("```"):
# tolerate ```json ... ``` fences from chat-tuned backbones
first = text.find("\n")
last = text.rfind("```")
if first != -1 and last != -1 and last > first:
text = text[first + 1 : last].strip()
return json.loads(text)
@dataclass
class _GenericTextClient:
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
config: VlmConfig
def generate_json(
self,
messages_batch: Sequence[Sequence[dict[str, Any]]],
*,
max_new_tokens: int | None = None,
temperature: float | None = None,
) -> list[Any]:
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
temp = temperature if temperature is not None else self.config.temperature
raw = self.generate_text(messages_batch, max_tok, temp)
out: list[Any] = []
for messages, text in zip(messages_batch, raw, strict=True):
try:
out.append(_strip_to_json(text))
continue
except (ValueError, json.JSONDecodeError):
pass
retry = list(messages) + [
{"role": "assistant", "content": text},
{
"role": "user",
"content": (
"Your previous reply was not valid JSON. "
"Reply with strictly valid JSON, no prose, no fences."
),
},
]
retry_text = self.generate_text([retry], max_tok, temp)[0]
out.append(_strip_to_json(retry_text))
return out
def make_vlm_client(config: VlmConfig) -> VlmClient:
"""Build the shared VLM client per the configured backend.
For ``stub``, callers should construct :class:`StubVlmClient` directly with
a responder callable. ``stub`` here is rejected to make accidental misuse
obvious.
"""
if config.backend == "stub":
raise ValueError(
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
)
if config.backend == "vllm":
return _make_vllm_client(config)
if config.backend == "transformers":
return _make_transformers_client(config)
if config.backend == "openai":
return _make_openai_client(config)
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
def _make_vllm_client(config: VlmConfig) -> VlmClient:
try:
from vllm import LLM, SamplingParams # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
) from exc
# Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
# convolution kernels — slower but functional.
import os as _os # noqa: PLC0415
if _os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
import torch as _torch # noqa: PLC0415
_torch.backends.cudnn.enabled = False
llm_kwargs: dict[str, Any] = {
"model": config.model_id,
"tensor_parallel_size": config.tensor_parallel_size,
"gpu_memory_utilization": config.gpu_memory_utilization,
"trust_remote_code": config.trust_remote_code,
}
if config.max_model_len is not None:
llm_kwargs["max_model_len"] = config.max_model_len
llm = LLM(**llm_kwargs)
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
# ``guided_decoding`` would speed up parsing but its API differs across
# vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
# wrapper already has a one-retry JSON-recovery path, so we skip it.
params = SamplingParams(max_tokens=max_tok, temperature=temp)
# ``llm.chat`` handles chat-template application + multimodal input
# extraction (image/video blocks) internally, which ``llm.generate``
# does not.
outputs = llm.chat([list(m) for m in batch], params)
return [o.outputs[0].text for o in outputs]
return _GenericTextClient(_gen, config)
def _make_transformers_client(config: VlmConfig) -> VlmClient:
try:
import torch # type: ignore[import-not-found]
import transformers # type: ignore[import-not-found]
from transformers import AutoProcessor # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
auto_cls = (
getattr(transformers, "AutoModelForImageTextToText", None)
or getattr(transformers, "AutoModelForVision2Seq", None)
)
if auto_cls is None:
raise ImportError(
"Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
"transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
"for VL models."
)
processor = AutoProcessor.from_pretrained(
config.model_id, trust_remote_code=config.trust_remote_code
)
import os as _os # noqa: PLC0415
use_accelerate = _os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
# post-load dispatch path (the alloc fails in accelerate's hook setup
# even with TBs of host RAM). Default to manual: load on CPU with
# ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
# ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
if use_accelerate:
model = auto_cls.from_pretrained(
config.model_id,
torch_dtype="auto",
device_map="auto",
low_cpu_mem_usage=True,
trust_remote_code=config.trust_remote_code,
)
else:
import torch as _torch # noqa: PLC0415
model = auto_cls.from_pretrained(
config.model_id,
torch_dtype=_torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=config.trust_remote_code,
)
model = model.to("cuda")
model.eval()
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
outs: list[str] = []
for messages in batch:
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(text=[text], return_tensors="pt").to(model.device)
with torch.no_grad():
gen = model.generate(
**inputs,
max_new_tokens=max_tok,
temperature=temp,
do_sample=temp > 0.0,
)
decoded = processor.batch_decode(
gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
)[0]
outs.append(decoded)
return outs
return _GenericTextClient(_gen, config)
def _make_openai_client(config: VlmConfig) -> VlmClient:
"""Backend that talks to any OpenAI-compatible server.
Compatible with ``vllm serve``, ``transformers serve``,
``ktransformers serve``, and hosted endpoints. By default the server
is expected to be already running. Set ``auto_serve=True`` to have
this client spawn one (default: ``transformers serve``), wait until
it's ready, and tear it down on process exit.
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
auto-converted to ``image_url`` data-URLs. Video blocks
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
multi-frame ``video_url`` items where supported.
"""
try:
from openai import OpenAI # type: ignore[import-not-found]
except ImportError as exc:
raise ImportError(
"openai package is required for backend='openai'. "
"Install with `pip install openai`."
) from exc
api_base = config.api_base
print(
f"[lerobot-annotate] backend=openai model={config.model_id} "
f"api_base={api_base} auto_serve={config.auto_serve}",
flush=True,
)
if config.auto_serve:
if _server_is_up(api_base):
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
else:
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
api_base = _spawn_inference_server(config)
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
client = OpenAI(base_url=api_base, api_key=config.api_key)
def _gen(
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
) -> list[str]:
outs: list[str] = []
for messages in batch:
api_messages = [_to_openai_message(m) for m in messages]
response = client.chat.completions.create(
model=config.model_id,
messages=api_messages,
max_tokens=max_tok,
temperature=temp,
)
outs.append(response.choices[0].message.content or "")
return outs
return _GenericTextClient(_gen, config)
def _server_is_up(api_base: str) -> bool:
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
import urllib.request # noqa: PLC0415
url = api_base.rstrip("/") + "/models"
try:
with urllib.request.urlopen(url, timeout=2) as resp:
return resp.status == 200
except Exception: # noqa: BLE001
return False
def _spawn_inference_server(config: VlmConfig) -> str:
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
accepts ``/v1/models``, and register a shutdown hook.
Streams the server's stdout/stderr to the parent terminal in
real-time on a background thread so users can see model-load
progress and errors as they happen.
Returns the full ``api_base`` URL the OpenAI client should use.
"""
import atexit # noqa: PLC0415
import shlex # noqa: PLC0415
import signal # noqa: PLC0415
import subprocess # noqa: PLC0415
import sys # noqa: PLC0415
import threading # noqa: PLC0415
import time # noqa: PLC0415
import urllib.request # noqa: PLC0415
cmd = config.serve_command
if not cmd:
cmd = (
f"transformers serve {shlex.quote(config.model_id)} "
f"--port {config.serve_port} --continuous-batching"
)
api_base = f"http://localhost:{config.serve_port}/v1"
print(f"[server] launching: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
# Watch the server output for the uvicorn readiness banner. This is
# more reliable than polling /v1/models because transformers serve
# rescans its cache on every model-list request, which can exceed
# the urllib timeout and trigger an infinite probe loop.
ready_event = threading.Event()
ready_markers = ("Uvicorn running", "Application startup complete")
def _stream_output() -> None:
assert proc.stdout is not None
for line in proc.stdout:
sys.stdout.write(f"[server] {line}")
sys.stdout.flush()
if any(marker in line for marker in ready_markers):
ready_event.set()
threading.Thread(target=_stream_output, daemon=True).start()
def _shutdown() -> None:
if proc.poll() is None:
print(f"[server] stopping pid={proc.pid}", flush=True)
proc.send_signal(signal.SIGINT)
try:
proc.wait(timeout=15)
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while time.monotonic() < deadline:
if proc.poll() is not None:
raise RuntimeError(
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
f"See [server] log lines above for the cause."
)
if ready_event.wait(timeout=2):
return api_base
proc.terminate()
raise RuntimeError(
f"[server] did not become ready within {config.serve_ready_timeout_s}s"
)
def _to_openai_message(message: dict[str, Any]) -> dict[str, Any]:
"""Convert an internal message dict to OpenAI chat format.
Internal image/video blocks (using PIL.Image objects) become
OpenAI ``image_url``/``video_url`` items via base64 data URLs.
"""
content = message.get("content")
if not isinstance(content, list):
return {"role": message["role"], "content": content}
out_blocks: list[dict[str, Any]] = []
for block in content:
block_type = block.get("type") if isinstance(block, dict) else None
if block_type == "text":
out_blocks.append({"type": "text", "text": block.get("text", "")})
elif block_type == "image":
out_blocks.append(
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
)
elif block_type == "video":
frames = block.get("video", [])
for img in frames:
out_blocks.append(
{"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}}
)
elif block_type == "video_url":
# Pass through to the OpenAI-compatible server unchanged.
out_blocks.append({"type": "video_url", "video_url": block["video_url"]})
else:
out_blocks.append(block)
return {"role": message["role"], "content": out_blocks}
def _pil_to_data_url(image: Any) -> str:
"""Encode a PIL.Image as a base64 data URL."""
import base64 # noqa: PLC0415
import io # noqa: PLC0415
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
return f"data:image/png;base64,{b64}"
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
"""Pass-through hook used by the vllm backend.
vllm exposes its own multimodal entry points that vary by version; for the
base flow we simply forward the raw message list and let the caller's
custom backend handle templating. Real deployments override this.
"""
return list(messages)
@@ -1,339 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Final parquet rewrite.
For every episode the writer:
1. reads the staged module outputs,
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
3. sorts each slice deterministically,
4. broadcasts the persistent slice across every frame in the episode,
5. for each frame, materializes the sublist of event rows whose timestamp
exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column,
7. adds a top-level ``tools`` column containing the JSON schema for ``say``,
8. writes the parquet shard back in place.
Invariants enforced here (and re-checked by the validator):
- per-episode persistent slice is byte-identical across every frame;
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
(timestamps come straight from the source parquet never recomputed);
- every row passes ``column_for_style(style)``.
"""
from __future__ import annotations
import json
import logging
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
)
from .reader import EpisodeRecord
from .staging import EpisodeStaging
logger = logging.getLogger(__name__)
SAY_TOOL_SCHEMA: dict[str, Any] = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
# events are bucketed per-frame, but within a frame we still want determinism
return (row.get("style") or "", row.get("role") or "")
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the persistent column's struct shape."""
style = row.get("style")
if style not in PERSISTENT_STYLES:
raise ValueError(
f"persistent slice contains row with non-persistent style {style!r}; "
"row would be misrouted under column_for_style()"
)
if "timestamp" not in row:
raise ValueError(f"persistent row missing timestamp: {row!r}")
return {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
"timestamp": float(row["timestamp"]),
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
}
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
style = row.get("style")
if style is not None and style not in EVENT_ONLY_STYLES:
raise ValueError(
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
)
if column_for_style(style) != LANGUAGE_EVENTS:
raise ValueError(f"event row with style {style!r} would not route to language_events")
return {
"role": str(row["role"]),
"content": None if row.get("content") is None else str(row["content"]),
"style": style,
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
}
def _normalize_tool_calls(value: Any) -> list[Any] | None:
if value is None:
return None
if not isinstance(value, list):
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
return list(value)
def _validate_atom_invariants(row: dict[str, Any]) -> None:
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
has_content = row.get("content") is not None
has_tools = row.get("tool_calls") is not None
if not (has_content or has_tools):
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
if row.get("style") is None and not has_tools:
raise ValueError(f"style=None requires tool_calls: {row!r}")
def _validate_speech_atom(row: dict[str, Any]) -> None:
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
if row.get("style") is not None:
return # not a speech atom
if row.get("role") != "assistant":
raise ValueError(f"speech atom must have role=assistant: {row!r}")
if row.get("content") is not None:
raise ValueError(f"speech atom must have content=null: {row!r}")
tool_calls = row.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
first = tool_calls[0]
if not isinstance(first, dict):
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
if first.get("type") != "function":
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
fn = first.get("function") or {}
if fn.get("name") != "say":
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
args = fn.get("arguments") or {}
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
@dataclass
class LanguageColumnsWriter:
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
drop_existing_subtask_index: bool = True
def write_all(
self,
records: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> list[Path]:
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
for record in records:
episodes_by_path[record.data_path].append(record)
written: list[Path] = []
for path, eps in episodes_by_path.items():
self._rewrite_one(path, eps, staging_dir, root)
written.append(path)
return written
def _rewrite_one(
self,
path: Path,
episodes: Sequence[EpisodeRecord],
staging_dir: Path,
root: Path,
) -> None:
table = pq.read_table(path)
n_rows = table.num_rows
# Ensure we cover every episode in the file. Episodes that don't have
# staging artifacts are passed through with empty annotation lists —
# this keeps the writer idempotent and safe for partial reruns.
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
for record in episodes:
staging = EpisodeStaging(staging_dir, record.episode_index)
staged_per_ep[record.episode_index] = staging.read_all()
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
for ep_index, ep_staged in staged_per_ep.items():
persistent_rows: list[dict[str, Any]] = []
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
for _module_name, rows in ep_staged.items():
for row in rows:
style = row.get("style")
if column_for_style(style) == LANGUAGE_PERSISTENT:
persistent_rows.append(row)
else:
event_rows.append(row)
persistent_rows.sort(key=_row_persistent_sort_key)
normalized_persistent = []
for r in persistent_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
normalized_persistent.append(_normalize_persistent_row(r))
persistent_by_ep[ep_index] = normalized_persistent
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
for r in event_rows:
_validate_atom_invariants(r)
_validate_speech_atom(r)
ts = float(r["timestamp"])
buckets[ts].append(_normalize_event_row(r))
for ts in list(buckets.keys()):
buckets[ts].sort(key=_row_event_sort_key)
events_by_ep_ts[ep_index] = buckets
episode_col = (
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
)
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
if episode_col is None or ts_col is None:
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
per_row_persistent: list[list[dict[str, Any]]] = []
per_row_events: list[list[dict[str, Any]]] = []
for i in range(n_rows):
ep = episode_col[i]
ts = float(ts_col[i])
per_row_persistent.append(persistent_by_ep.get(ep, []))
buckets = events_by_ep_ts.get(ep, {})
per_row_events.append(buckets.get(ts, []))
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
pq.write_table(new_table, path)
def _materialize_table(
self,
table: pa.Table,
persistent: list[list[dict[str, Any]]],
events: list[list[dict[str, Any]]],
*,
drop_old: bool,
) -> pa.Table:
cols = []
names = []
for name in table.column_names:
if drop_old and name == "subtask_index":
continue
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"):
continue # we'll re-add canonical versions
cols.append(table.column(name))
names.append(name)
# We let pyarrow infer struct/list schema rather than passing the
# canonical type from `lerobot.datasets.language` directly: that type
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — see PR 1's
# `tests/datasets/test_language.py` which exercises the same flow.
persistent_arr = pa.array(persistent)
events_arr = pa.array(events)
cols.extend([persistent_arr, events_arr])
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
# Dataset-level tools column. Store the JSON schema as a string per
# row (broadcast-identical, parquet dictionary-encodes it) — string
# storage avoids requiring pa.json_() on every consumer.
tools_json = json.dumps([SAY_TOOL_SCHEMA], sort_keys=True)
tools_arr = pa.array([tools_json] * table.num_rows, type=pa.string())
cols.append(tools_arr)
names.append("tools")
return pa.Table.from_arrays(cols, names=names)
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
"""Build a canonical speech tool-call atom for the events column."""
return {
"role": "assistant",
"content": None,
"style": None,
"timestamp": float(timestamp),
"tool_calls": [
{
"type": "function",
"function": {
"name": "say",
"arguments": {"text": text},
},
}
],
}
def normalize_rows_for_writer(
rows: Iterable[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Helper used by tests/validators to partition a flat row list into
(persistent_rows, event_rows) using ``column_for_style``.
"""
persistent: list[dict[str, Any]] = []
events: list[dict[str, Any]] = []
for row in rows:
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
events.append(row)
return persistent, events
@@ -17,6 +17,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
"""
import logging
import sys
import time
from threading import Event, Lock, Thread
from typing import TYPE_CHECKING, Any
@@ -41,6 +42,7 @@ from ..utils import get_cv2_rotation
from .configuration_realsense import RealSenseCameraConfig
logger = logging.getLogger(__name__)
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
class RealSenseCamera(Camera):
@@ -114,7 +116,7 @@ class RealSenseCamera(Camera):
Args:
config: The configuration settings for the camera.
"""
require_package("pyrealsense2", extra="intelrealsense")
require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
super().__init__(config)
self.config = config
@@ -131,6 +133,9 @@ class RealSenseCamera(Camera):
self.rs_pipeline: rs.pipeline | None = None
self.rs_profile: rs.pipeline_profile | None = None
# Meters per uint16 unit on the depth stream. Queried from the device
# at connect() time. Typical D-series value is 0.001 (= 1 mm/unit).
self.depth_scale: float | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
@@ -188,6 +193,17 @@ class RealSenseCamera(Camera):
) from e
self._configure_capture_settings()
# Query depth scale (meters per uint16 unit) when depth is enabled so
# consumers can convert the raw z16 stream to metric distances.
if self.use_depth and self.rs_profile is not None:
try:
depth_sensor = self.rs_profile.get_device().first_depth_sensor()
self.depth_scale = float(depth_sensor.get_depth_scale())
except RuntimeError as e:
logger.warning(f"{self}: failed to query depth scale ({e}); falling back to 0.001 m/unit.")
self.depth_scale = 0.001
self._start_read_thread()
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
@@ -530,7 +546,6 @@ class RealSenseCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
@@ -573,7 +588,6 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
@@ -609,6 +623,78 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
"""Read the latest depth frame asynchronously, in metric meters.
Mirrors :meth:`async_read` but returns the depth stream rather than the
color stream. Output is ``np.uint16`` of shape ``(H, W)``.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
the background read thread is not running.
TimeoutError: If no frame becomes available within ``timeout_ms``.
"""
if not self.use_depth:
raise RuntimeError(
f"{self}: cannot read depth — camera was configured with use_depth=False."
)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(
f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms."
)
with self.frame_lock:
depth_frame = self.latest_depth_frame
self.new_frame_event.clear()
if depth_frame is None:
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
return depth_frame
@check_if_not_connected
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
"""Return the most recent depth frame in metric meters (peeking).
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
Output is ``np.float32`` of shape ``(H, W)`` in meters.
Raises:
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
no depth frame has been captured yet.
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
"""
if not self.use_depth:
raise RuntimeError(
f"{self}: cannot read depth — camera was configured with use_depth=False."
)
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
with self.frame_lock:
depth_frame = self.latest_depth_frame
timestamp = self.latest_timestamp
if depth_frame is None or timestamp is None:
raise RuntimeError(f"{self} has not captured any depth frames yet.")
age_ms = (time.perf_counter() - timestamp) * 1e3
if age_ms > max_age_ms:
raise TimeoutError(
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
)
return depth_frame
def disconnect(self) -> None:
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.
@@ -632,6 +718,8 @@ class RealSenseCamera(Camera):
self.rs_pipeline = None
self.rs_profile = None
self.depth_scale = None
with self.frame_lock:
self.latest_color_frame = None
self.latest_depth_frame = None
+5 -1
View File
@@ -41,8 +41,12 @@ def cfg_to_group(
return tag
return tag[:max_tag_length]
if cfg.is_reward_model_training:
trainable_tag = f"reward_model:{cfg.reward_model.type}"
else:
trainable_tag = f"policy:{cfg.policy.type}"
lst = [
f"policy:{cfg.policy.type}",
trainable_tag,
f"seed:{cfg.seed}",
]
if cfg.dataset is not None:
+2 -4
View File
@@ -21,9 +21,9 @@ are intentionally NOT re-exported here to avoid circular dependencies
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
"""
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -40,12 +40,10 @@ __all__ = [
"PolicyFeature",
"RTCAttentionSchedule",
# Config classes
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
]
+80
View File
@@ -0,0 +1,80 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str = ""
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str = ""
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
# Maximum number of frames to buffer per camera when using streaming encoding.
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
encoder_queue_maxsize: int = 30
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
def stamp_repo_id(self) -> None:
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
Must be called explicitly at dataset *creation* time not on resume,
where the existing ``repo_id`` (already stamped) must be preserved.
"""
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"
+2 -2
View File
@@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from lerobot.transforms import ImageTransformsConfig
from lerobot.utils.import_utils import get_safe_default_codec
from lerobot.utils.import_utils import get_safe_default_video_backend
@dataclass
@@ -34,7 +34,7 @@ class DatasetConfig:
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
video_backend: str = field(default_factory=get_safe_default_video_backend)
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
-193
View File
@@ -1,193 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
if self.stream is not None and self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(_PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(_PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)
@@ -1,47 +0,0 @@
blend:
memory_update:
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "emitted_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.16
bindings:
prior_plan: "nth_prev(style=plan, offset=1)"
current_plan: "emitted_at(t, style=plan)"
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
high_level_subtask:
weight: 0.15
bindings:
next_subtask: "nth_next(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
- {role: assistant, content: "${next_subtask}", stream: high_level, target: true}
low_level_execution:
weight: 0.35
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
ask_vqa:
weight: 0.20
messages:
- {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+163
View File
@@ -0,0 +1,163 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import builtins
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RewardModelConfig")
logger = logging.getLogger(__name__)
@dataclass
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Base configuration for reward models.
Args:
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
"""
# Reuses PolicyFeature
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None
pretrained_path: str | None = None
push_to_hub: bool = False
repo_id: str | None = None
# Hub metadata
license: str | None = None
tags: list[str] | None = None
private: bool | None = None
def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
@property
def type(self) -> str:
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@property
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@property
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@property
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
return None
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
pass
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**reward_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
# HACK: Parse the original config to get the config subclass, so that we can
# apply cli overrides.
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
with open(config_file) as f:
config = json.load(f)
config.pop("type", None)
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(config, f)
config_file = f.name
cli_overrides = reward_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
+87 -29
View File
@@ -13,7 +13,9 @@
# limitations under the License.
import builtins
import datetime as dt
import json
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@@ -26,18 +28,57 @@ from lerobot import envs
from lerobot.configs import parser
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
TRAIN_CONFIG_NAME = "train_config.json"
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
legacy_fields = (
"use_rabc",
"rabc_progress_path",
"rabc_kappa",
"rabc_epsilon",
"rabc_head_mode",
)
if not any(key in config for key in legacy_fields):
return None
migrated_config = dict(config)
use_rabc = bool(migrated_config.pop("use_rabc", False))
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
rabc_kappa = migrated_config.pop("rabc_kappa", None)
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
# New configs may already define sample_weighting explicitly. In that case,
# legacy fields are ignored after being stripped from the payload.
if migrated_config.get("sample_weighting") is None and use_rabc:
sample_weighting: dict[str, Any] = {"type": "rabc"}
if rabc_progress_path is not None:
sample_weighting["progress_path"] = rabc_progress_path
if rabc_kappa is not None:
sample_weighting["kappa"] = rabc_kappa
if rabc_epsilon is not None:
sample_weighting["epsilon"] = rabc_epsilon
if rabc_head_mode is not None:
sample_weighting["head_mode"] = rabc_head_mode
migrated_config["sample_weighting"] = sample_weighting
return migrated_config
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
reward_model: RewardModelConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
@@ -72,27 +113,41 @@ class TrainPipelineConfig(HubMixin):
wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
# Sample weighting configuration (e.g., for RA-BC training)
sample_weighting: SampleWeightingConfig | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None)
@property
def is_reward_model_training(self) -> bool:
"""True when the config targets a reward model rather than a policy."""
return self.reward_model is not None
@property
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
"""Return whichever config (policy or reward_model) is active."""
if self.is_reward_model_training:
return self.reward_model # type: ignore[return-value]
return self.policy # type: ignore[return-value]
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
reward_model_path = parser.get_path_arg("reward_model")
if reward_model_path:
cli_overrides = parser.get_cli_overrides("reward_model")
self.reward_model = RewardModelConfig.from_pretrained(
reward_model_path, cli_overrides=cli_overrides
)
self.reward_model.pretrained_path = str(Path(reward_model_path))
elif policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
@@ -108,18 +163,22 @@ class TrainPipelineConfig(HubMixin):
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
if self.reward_model is not None:
self.reward_model.pretrained_path = str(policy_dir)
self.checkpoint_path = policy_dir.parent
if self.policy is None:
if self.policy is None and self.reward_model is None:
raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
"Neither policy nor reward_model is configured. "
"Please specify one with `--policy.path` or `--reward_model.path`."
)
active_cfg = self.trainable_config
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
self.job_name = f"{active_cfg.type}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
self.job_name = f"{self.env.type}_{active_cfg.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError(
@@ -137,26 +196,16 @@ class TrainPipelineConfig(HubMixin):
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset()
self.optimizer = active_cfg.get_optimizer_preset()
self.scheduler = active_cfg.get_scheduler_preset()
if self.policy.push_to_hub and not self.policy.repo_id:
raise ValueError(
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
if self.use_rabc and not self.rabc_progress_path:
# Auto-detect from dataset path
repo_id = self.dataset.repo_id
if self.dataset.root:
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
else:
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
"""Keys for draccus pretrained-path loading."""
return ["policy", "reward_model"]
def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
@@ -207,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
) from e
cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
if migrated_config is not None:
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
json.dump(migrated_config, f)
config_file = f.name
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args)
+19 -15
View File
@@ -37,21 +37,24 @@ from .dataset_tools import (
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import (
check_video_encoder_config_pyav,
detect_available_encoders_pyav,
get_codec,
)
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
from .video_utils import (
DepthEncoderConfig,
VideoEncoderConfig,
VideoEncodingManager,
camera_encoder_defaults,
depth_encoder_defaults,
)
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
# and legacy migration constants are intentionally NOT re-exported here.
@@ -61,26 +64,27 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"DepthEncoderConfig",
"VideoEncoderConfig",
"VideoEncodingManager",
"camera_encoder_defaults",
"depth_encoder_defaults",
"add_features",
"aggregate_datasets",
"aggregate_pipeline_dataset_features",
"aggregate_stats",
"check_video_encoder_config_pyav",
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"detect_available_encoders_pyav",
"get_codec",
"get_feature_stats",
"load_episodes",
"make_dataset",
+14 -18
View File
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
pd.DataFrame: Updated DataFrame with adjusted indices.
"""
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
df["index"] = df["index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
df["index"] = df["index"] + dst_meta.info.total_frames
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
@@ -225,9 +225,9 @@ def update_meta_data(
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
return df
@@ -237,8 +237,8 @@ def aggregate_datasets(
aggr_repo_id: str,
roots: list[Path] | None = None,
aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
chunk_size: int | None = None,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
@@ -313,8 +313,8 @@ def aggregate_datasets(
# to avoid interference between different source datasets
data_idx.pop("src_to_dst", None)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
dst_meta.info.total_episodes += src_meta.total_episodes
dst_meta.info.total_frames += src_meta.total_frames
finalize_aggregation(dst_meta, all_metadata)
logging.info("Aggregation complete.")
@@ -332,7 +332,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx: Dictionary tracking video chunk and file indices.
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
@@ -417,6 +416,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
concatenate_video_files(
[dst_path, src_path],
dst_path,
compatibility_check=True,
)
# Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration
@@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata):
write_tasks(aggr_meta.tasks, aggr_meta.root)
logging.info("write info")
aggr_meta.info.update(
{
"total_tasks": len(aggr_meta.tasks),
"total_episodes": sum(m.total_episodes for m in all_metadata),
"total_frames": sum(m.total_frames for m in all_metadata),
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
}
)
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
write_info(aggr_meta.info, aggr_meta.root)
logging.info("write stats")
+1 -1
View File
@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] in {"string", "language"}:
if features[key]["dtype"] == "string":
continue
if features[key]["dtype"] in ["image", "video"]:
+61 -27
View File
@@ -34,22 +34,21 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_json,
write_stats,
write_tasks,
)
from .utils import (
DEFAULT_EPISODES_PATH,
INFO_PATH,
check_version_compatibility,
get_safe_version,
has_legacy_hub_download_metadata,
is_valid_version,
update_chunk_file_indices,
)
from .video_utils import get_video_info
from .video_utils import VideoEncoderConfig, get_video_info
CODEBASE_VERSION = "v3.0"
@@ -176,6 +175,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -226,7 +226,7 @@ class LeRobotDatasetMetadata:
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
return packaging.version.parse(self.info.codebase_version)
def get_data_file_path(self, ep_index: int) -> Path:
"""Return the relative parquet file path for the given episode index.
@@ -281,27 +281,27 @@ class LeRobotDatasetMetadata:
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
return self.info.data_path
@property
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["video_path"]
return self.info.video_path
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
return self.info["robot_type"]
return self.info.robot_type
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
return self.info.fps
@property
def features(self) -> dict[str, dict]:
"""All features contained in the dataset."""
return self.info["features"]
return self.info.features
@property
def image_keys(self) -> list[str]:
@@ -313,6 +313,20 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def depth_keys(self) -> list[str]:
"""Keys to access depth-map modalities stored as videos.
A depth video key is a feature whose ``info`` dict carries
``"video.is_depth_map": True`` (set either at creation time by the user
or after the first encoded episode by :meth:`update_video_info`).
"""
return [
key
for key, ft in self.features.items()
if ft["dtype"] == "video" and ft.get("info", {}).get("video.is_depth_map", False)
]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
@@ -331,32 +345,32 @@ class LeRobotDatasetMetadata:
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
return self.info.total_episodes
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
return self.info.total_frames
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
return self.info.total_tasks
@property
def chunks_size(self) -> int:
"""Max number of files per chunk."""
return self.info["chunks_size"]
return self.info.chunks_size
@property
def data_files_size_in_mb(self) -> int:
"""Max size of data file in mega bytes."""
return self.info["data_files_size_in_mb"]
return self.info.data_files_size_in_mb
@property
def video_files_size_in_mb(self) -> int:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
return self.info.video_files_size_in_mb
def get_task_index(self, task: str) -> int | None:
"""
@@ -500,29 +514,48 @@ class LeRobotDatasetMetadata:
self._save_episode_metadata(episode_dict)
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["total_tasks"] = len(self.tasks)
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info.total_episodes += 1
self.info.total_frames += episode_length
self.info.total_tasks = len(self.tasks)
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None:
"""
def update_video_info(
self,
video_key: str | None = None,
camera_encoder_config: VideoEncoderConfig | None = None,
) -> None:
"""Populate per-feature video info in ``info.json``.
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
Args:
video_key: If provided, only update this video key. Otherwise update
all video keys in the dataset.
camera_encoder_config: Encoder configuration used to produce the
videos. When provided, its fields are recorded as
``video.<field>`` entries alongside the stream-derived
``video.*`` entries (see :func:`get_video_info`).
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
existing = self.features[key].get("info") or {}
# Repopulate when codec metadata is missing — preserves user-provided
# markers like ``video.is_depth_map`` while still recording stream
# info on the first episode.
if not existing or "video.codec" not in existing:
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
stream_info = get_video_info(video_path, camera_encoder_config=camera_encoder_config)
merged = {**existing, **stream_info}
self.info.features[key]["info"] = merged
def update_chunk_settings(
self,
@@ -544,17 +577,17 @@ class LeRobotDatasetMetadata:
if chunks_size is not None:
if chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
self.info["chunks_size"] = chunks_size
self.info.chunks_size = chunks_size
if data_files_size_in_mb is not None:
if data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
self.info["data_files_size_in_mb"] = data_files_size_in_mb
self.info.data_files_size_in_mb = data_files_size_in_mb
if video_files_size_in_mb is not None:
if video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
self.info.video_files_size_in_mb = video_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
@@ -633,6 +666,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
@@ -650,7 +684,7 @@ class LeRobotDatasetMetadata:
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
"Either remove video features from the features dict, or set 'use_videos=True'."
)
write_json(obj.info, obj.root / INFO_PATH)
write_info(obj.info, obj.root)
obj.revision = None
obj._pq_writer = None
obj.latest_episode = None
+33 -8
View File
@@ -32,7 +32,13 @@ from .io_utils import (
hf_transform_to_torch,
load_nested_dataset,
)
from .video_utils import decode_video_frames
from .video_utils import decode_depth_frames, decode_video_frames
from .depth_utils import (
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
)
class DatasetReader:
@@ -237,17 +243,31 @@ class DatasetReader:
"""
ep = self._meta.episodes[ep_idx]
depth_keys = set(self._meta.depth_keys)
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
)
if vid_key in depth_keys:
feature_info = self._meta.features[vid_key].get("info") or {}
frames = decode_depth_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN),
depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX),
shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT),
use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG),
)
else:
frames = decode_video_frames(
video_path,
shifted_query_ts,
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())
@@ -295,4 +315,9 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item
+55 -73
View File
@@ -62,7 +62,7 @@ from .utils import (
DEFAULT_EPISODES_PATH,
update_chunk_file_indices,
)
from .video_utils import encode_video_frames, get_video_info
from .video_utils import VideoEncoderConfig, encode_video_frames, get_video_info
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -92,6 +92,7 @@ def delete_episodes(
episode_indices: list[int],
output_dir: str | Path | None = None,
repo_id: str | None = None,
camera_encoder_config: VideoEncoderConfig | None = None,
) -> LeRobotDataset:
"""Delete episodes from a LeRobotDataset and create a new dataset.
@@ -100,6 +101,7 @@ def delete_episodes(
episode_indices: List of episode indices to delete.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
"""
if not episode_indices:
raise ValueError("No episodes to delete")
@@ -132,7 +134,7 @@ def delete_episodes(
video_metadata = None
if dataset.meta.video_keys:
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping, camera_encoder_config)
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
@@ -154,6 +156,7 @@ def split_dataset(
dataset: LeRobotDataset,
splits: dict[str, float | list[int]],
output_dir: str | Path | None = None,
camera_encoder_config: VideoEncoderConfig | None = None,
) -> dict[str, LeRobotDataset]:
"""Split a LeRobotDataset into multiple smaller datasets.
@@ -162,6 +165,7 @@ def split_dataset(
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
Examples:
Split by specific episodes
@@ -222,7 +226,9 @@ def split_dataset(
video_metadata = None
if dataset.meta.video_keys:
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
video_metadata = _copy_and_reindex_videos(
dataset, new_meta, episode_mapping, camera_encoder_config
)
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
@@ -578,8 +584,7 @@ def _keep_episodes_from_video_with_av(
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
camera_encoder_config: VideoEncoderConfig | None = None,
) -> None:
"""Keep only specified episodes from a video file using PyAV.
@@ -593,9 +598,10 @@ def _keep_episodes_from_video_with_av(
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
"""
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
from fractions import Fraction
import av
@@ -619,12 +625,12 @@ def _keep_episodes_from_video_with_av(
# Convert fps to Fraction for PyAV compatibility.
fps_fraction = Fraction(fps).limit_denominator(1000)
v_out = out.add_stream(vcodec, rate=fps_fraction)
v_out = out.add_stream(camera_encoder_config.vcodec, rate=fps_fraction)
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
v_out.width = v_in.codec_context.width
v_out.height = v_in.codec_context.height
v_out.pix_fmt = pix_fmt
v_out.pix_fmt = camera_encoder_config.pix_fmt
# Set time_base to match the frame rate for proper timestamp handling.
v_out.time_base = Fraction(1, int(fps))
@@ -687,8 +693,7 @@ def _copy_and_reindex_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_mapping: dict[int, int],
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
camera_encoder_config: VideoEncoderConfig | None = None,
) -> dict[int, dict]:
"""Copy and filter video files, only re-encoding files with deleted episodes.
@@ -700,10 +705,13 @@ def _copy_and_reindex_videos(
src_dataset: Source dataset to copy from
dst_meta: Destination metadata object
episode_mapping: Mapping from old episode indices to new indices
camera_encoder_config: Video encoder settings used when re-encoding segments (default: :class:`VideoEncoderConfig()`).
Returns:
dict mapping episode index to its video metadata (chunk_index, file_index, timestamps)
"""
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
@@ -792,8 +800,7 @@ def _copy_and_reindex_videos(
dst_video_path,
episodes_to_keep_ranges,
src_dataset.meta.fps,
vcodec,
pix_fmt,
camera_encoder_config,
)
cumulative_ts = 0.0
@@ -897,14 +904,10 @@ def _copy_and_reindex_episodes_metadata(
dst_meta.finalize()
dst_meta.info.update(
{
"total_episodes": len(episode_mapping),
"total_frames": total_frames,
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
"splits": {"train": f"0:{len(episode_mapping)}"},
}
)
dst_meta.info.total_episodes = len(episode_mapping)
dst_meta.info.total_frames = total_frames
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
write_info(dst_meta.info, dst_meta.root)
if not all_stats:
@@ -1069,21 +1072,20 @@ def _copy_episodes_metadata_and_stats(
if episodes_dir.exists():
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
dst_meta.info.update(
{
"total_episodes": src_dataset.meta.total_episodes,
"total_frames": src_dataset.meta.total_frames,
"total_tasks": src_dataset.meta.total_tasks,
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
}
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
dst_meta.info.total_frames = src_dataset.meta.total_frames
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
# Preserve original splits if available, otherwise create default
dst_meta.info.splits = (
src_dataset.meta.info.splits
if src_dataset.meta.info.splits
else {"train": f"0:{src_dataset.meta.total_episodes}"}
)
if dst_meta.video_keys and src_dataset.meta.video_keys:
for key in dst_meta.video_keys:
if key in src_dataset.meta.features:
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
"info", {}
)
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
write_info(dst_meta.info, dst_meta.root)
@@ -1269,11 +1271,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: list[int],
temp_dir: Path,
fps: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
camera_encoder_config: VideoEncoderConfig,
num_calibration_frames: int = 30,
) -> float:
"""Estimate MB per frame by encoding a small calibration sample.
@@ -1287,11 +1285,7 @@ def _estimate_frame_size_via_calibration(
episode_indices: List of episode indices being processed.
temp_dir: Temporary directory for calibration files.
fps: Frames per second for video encoding.
vcodec: Video codec (libsvtav1, h264, hevc).
pix_fmt: Pixel format (yuv420p, etc.).
g: GOP size (group of pictures).
crf: Constant Rate Factor (quality).
fast_decode: Fast decode tuning parameter.
camera_encoder_config: Video encoder settings used for calibration encoding.
num_calibration_frames: Number of frames to use for calibration (default: 30).
Returns:
@@ -1327,11 +1321,7 @@ def _estimate_frame_size_via_calibration(
imgs_dir=calibration_dir,
video_path=calibration_video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
camera_encoder_config=camera_encoder_config,
overwrite=True,
)
@@ -1525,7 +1515,7 @@ def modify_tasks(
write_tasks(new_task_df, root)
# Update info.json
dataset.meta.info["total_tasks"] = len(unique_tasks)
dataset.meta.info.total_tasks = len(unique_tasks)
write_info(dataset.meta.info, root)
# Reload metadata to reflect changes
@@ -1649,11 +1639,7 @@ def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path | None = None,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
camera_encoder_config: VideoEncoderConfig | None = None,
episode_indices: list[int] | None = None,
num_workers: int = 4,
max_episodes_per_batch: int | None = None,
@@ -1668,11 +1654,7 @@ def convert_image_to_video_dataset(
dataset: The source LeRobot dataset with images
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
@@ -1681,6 +1663,9 @@ def convert_image_to_video_dataset(
Returns:
New LeRobotDataset with images encoded as videos
"""
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
@@ -1704,7 +1689,10 @@ def convert_image_to_video_dataset(
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
logging.info(
f"Video codec: {camera_encoder_config.vcodec}, pixel format: {camera_encoder_config.pix_fmt}, "
f"GOP: {camera_encoder_config.g}, CRF: {camera_encoder_config.crf}"
)
# Create new features dict, converting image features to video features
new_features = {}
@@ -1774,11 +1762,7 @@ def convert_image_to_video_dataset(
episode_indices=episode_indices,
temp_dir=temp_dir,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
camera_encoder_config=camera_encoder_config,
)
logging.info(f"Processing camera: {img_key}")
@@ -1820,11 +1804,7 @@ def convert_image_to_video_dataset(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
camera_encoder_config=camera_encoder_config,
overwrite=True,
)
@@ -1858,10 +1838,10 @@ def convert_image_to_video_dataset(
episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info
new_meta.info["total_episodes"] = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
new_meta.info["total_tasks"] = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
new_meta.info.total_episodes = len(episode_indices)
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
new_meta.info.total_tasks = dataset.meta.total_tasks
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
@@ -1870,7 +1850,9 @@ def convert_image_to_video_dataset(
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
new_meta.info.features[img_key]["info"] = get_video_info(
video_path, camera_encoder_config=camera_encoder_config
)
write_info(new_meta.info, new_meta.root)
+46 -11
View File
@@ -46,15 +46,19 @@ from .io_utils import (
write_info,
)
from .utils import (
DEFAULT_DEPTH_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
)
from .video_utils import (
DepthEncoderConfig,
StreamingVideoEncoder,
VideoEncoderConfig,
concatenate_video_files,
encode_video_frames,
get_video_duration_in_s,
is_depth_feature,
)
logger = logging.getLogger(__name__)
@@ -65,14 +69,19 @@ def _encode_video_worker(
episode_index: int,
root: Path,
fps: int,
vcodec: str = "libsvtav1",
camera_encoder_config: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
img_dir,
temp_path,
fps,
camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads,
overwrite=True,
)
shutil.rmtree(img_dir)
return temp_path
@@ -89,33 +98,40 @@ class DatasetWriter:
self,
meta: LeRobotDatasetMetadata,
root: Path,
vcodec: str,
camera_encoder_config: VideoEncoderConfig,
encoder_threads: int | None,
batch_encoding_size: int,
streaming_encoder: StreamingVideoEncoder | None = None,
initial_frames: int = 0,
depth_encoder_config: DepthEncoderConfig | None = None,
):
"""Initialize the writer with metadata, codec, and encoding config.
"""Initialize the writer with metadata, codec, and encoder config.
Args:
meta: Dataset metadata instance (used for feature schema, chunk
settings, and episode persistence).
root: Local dataset root directory.
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
encoder_threads: Threads per encoder instance. ``None`` for auto.
camera_encoder_config: Video encoder settings applied to all cameras.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
for real-time encoding. ``None`` disables streaming mode.
initial_frames: Starting frame count (non-zero when resuming).
depth_encoder_config: Optional depth-map encoder config used in
place of ``camera_encoder_config`` for keys present in
``meta.depth_keys``.
"""
self._meta = meta
self._root = root
self._vcodec = vcodec
self._camera_encoder_config = camera_encoder_config
self._depth_encoder_config = depth_encoder_config
self._encoder_threads = encoder_threads
self._batch_encoding_size = batch_encoding_size
self._streaming_encoder = streaming_encoder
# Writer state
self.image_writer: AsyncImageWriter | None = None
self.episode_buffer: dict = self._create_episode_buffer()
@@ -135,8 +151,16 @@ class DatasetWriter:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _is_depth_image_key(self, image_key: str) -> bool:
"""Whether *image_key* is a depth feature stored as per-frame images."""
ft = self._meta.features.get(image_key)
if ft is None or ft.get("dtype") != "image":
return False
return is_depth_feature(ft.get("info") or {})
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
path_template = DEFAULT_DEPTH_PATH if self._is_depth_image_key(image_key) else DEFAULT_IMAGE_PATH
fpath = path_template.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
@@ -284,7 +308,7 @@ class DatasetWriter:
episode_index,
self._root,
self._meta.fps,
self._vcodec,
self._camera_encoder_config,
self._encoder_threads,
): video_key
for video_key in self._meta.video_keys
@@ -495,7 +519,13 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key)
is_depth_key = video_key in set(self._meta.depth_keys)
cfg_for_info = (
self._depth_encoder_config
if is_depth_key and self._depth_encoder_config is not None
else self._camera_encoder_config
)
self._meta.update_video_info(video_key, camera_encoder_config=cfg_for_info)
write_info(self._meta.info, self._meta.root)
metadata = {
@@ -564,7 +594,12 @@ class DatasetWriter:
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
return _encode_video_worker(
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
video_key,
episode_index,
self._root,
self._meta.fps,
self._camera_encoder_config,
self._encoder_threads,
)
def close_writer(self) -> None:
+189
View File
@@ -0,0 +1,189 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
"""
import math
from typing import Literal
import numpy as np
import torch
from numpy.typing import NDArray
DEPTH_QUANT_BITS: int = 12
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
_MM_PER_METRE: float = 1000.0
_UINT16_MAX: int = 65535
DEFAULT_DEPTH_MIN: float = 0.01
DEFAULT_DEPTH_MAX: float = 10.0
DEFAULT_DEPTH_SHIFT: float = 3.5
DEFAULT_DEPTH_USE_LOG: bool = True
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
"""Ensure ``log(depth_min + shift)`` is finite."""
if depth_min + shift <= 0:
raise ValueError(
f"depth_min + shift must be positive for logarithmic quantization, "
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
)
def _depth_input_to_float32_and_unit(
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
input_unit: Literal["auto", "m", "mm"],
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
"""Depth as float32 in the chosen unit, plus the resolved unit."""
if isinstance(depth, torch.Tensor):
t = depth.detach().cpu()
arr = t.numpy()
is_floating = t.is_floating_point()
else:
arr = np.asarray(depth)
is_floating = np.issubdtype(arr.dtype, np.floating)
resolved_unit: Literal["m", "mm"]
if input_unit == "auto":
resolved_unit = "m" if is_floating else "mm"
else:
resolved_unit = input_unit
# Convert to float32 to keep typing consistency
return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit
def quantize_depth(
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
*,
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> NDArray[np.uint16]:
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
Depth maps are packed into 12-bit integer frames so they fit in standard
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
Logarithmic quantization is the default because it allocates more quanta
to near-range depth, which matches the (1/depth) error profile of typical
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
**Input units**:
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
- ``input_unit="mm"``: interpret input values as millimetres.
- ``input_unit="m"``: interpret input values as metres.
Quantization math runs in the **resolved input unit**.
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
Args:
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
depth_min: Depth (metres) at quantum ``0``.
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
use_log: If ``True`` (default), quantize in log space.
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
Returns:
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
``[0, DEPTH_QMAX]``.
Raises:
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
"""
if input_unit not in ("auto", "m", "mm"):
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_u + shift_u))
log_max = math.log(float(depth_max_u + shift_u))
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
else:
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX)
return out.astype(np.uint16, copy=False)
def dequantize_depth(
quantized: NDArray[np.uint16] | torch.Tensor,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
*,
output_unit: Literal["m", "mm"] = "mm",
) -> NDArray[np.uint16] | NDArray[np.float32]:
"""Inverse of :func:`quantize_depth`.
Tuning arguments **must match** :func:`quantize_depth`.
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
the requested output unit.
Args:
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
``[depth_min, depth_max]``.
Returns:
Depth map in the requested unit and dtype.
Raises:
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
"""
if output_unit not in ("m", "mm"):
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
if isinstance(quantized, torch.Tensor):
quantized = quantized.detach().cpu().numpy()
q = np.asarray(quantized, dtype=np.uint16, order="K")
norm = q.astype(np.float32, copy=False) / DEPTH_QMAX
depth_min_mm = np.float32(depth_min * _MM_PER_METRE)
depth_max_mm = np.float32(depth_max * _MM_PER_METRE)
shift_mm = np.float32(shift * _MM_PER_METRE)
if use_log:
_validate_log_quant_params(depth_min, shift)
log_min = math.log(float(depth_min_mm + shift_mm))
log_max = math.log(float(depth_max_mm + shift_mm))
depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm
else:
depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm
depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False)
if output_unit == "m":
return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False)
mm = np.rint(depth_mm).clip(0, _UINT16_MAX)
return mm.astype(np.uint16, copy=False)
+7 -4
View File
@@ -19,6 +19,7 @@ from pprint import pformat
import torch
from lerobot.configs import PreTrainedConfig
from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
Args:
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
``{observation,action,reward}_delta_indices`` properties used below.
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
delta_timestamps against.
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
if not cfg.dataset.streaming:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
+33 -37
View File
@@ -22,18 +22,13 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
DatasetInfo,
)
@@ -51,13 +46,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -90,8 +79,8 @@ def create_empty_dataset_info(
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
) -> DatasetInfo:
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
Args:
codebase_version (str): The version of the LeRobot codebase.
@@ -99,25 +88,24 @@ def create_empty_dataset_info(
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
Returns:
dict: A dictionary with the initial dataset metadata.
DatasetInfo: A typed dataset information object with initial metadata.
"""
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
return DatasetInfo(
codebase_version=codebase_version,
fps=fps,
features=features,
robot_type=robot_type,
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
data_path=DEFAULT_DATA_PATH,
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
)
def check_delta_timestamps(
@@ -254,8 +242,6 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return ""
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
@@ -308,10 +294,20 @@ def validate_feature_image_or_video(
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
actual_shape = tuple(value.shape)
expected = tuple(expected_shape)
if len(expected) == 2:
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
h, w = expected
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
if not valid:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
elif len(expected) == 3:
c, h, w = expected
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
else:
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
+63 -7
View File
@@ -41,15 +41,56 @@ def safe_stop_image_writer(func):
return wrapper
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
# Single-channel dtypes that PIL natively maps to the matching mode
# (``uint8`` → ``L``, ``uint16`` → ``I;16``, ``float32`` → ``F``).
GRAYSCALE_DTYPES: tuple[np.dtype, ...] = (
np.dtype("uint8"),
np.dtype("uint16"),
np.dtype("float32"),
)
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
Behaviour by shape:
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
The native dtype is preserved using the matching PIL mode
(``L`` / ``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
(existing behaviour, gated by ``range_check``).
Other shapes / channel counts raise ``NotImplementedError`` or
``ValueError``.
"""
if image_array.ndim not in (2, 3):
raise ValueError(
f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image."
)
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
if image_array.ndim == 2:
if image_array.dtype not in GRAYSCALE_DTYPES:
raise ValueError(
f"Unsupported single-channel image dtype: {image_array.dtype}. "
f"Supported dtypes: {sorted(str(d) for d in GRAYSCALE_DTYPES)}."
)
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
# 3D path: must be RGB (3 channels), channels-first or channels-last.
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
elif image_array.shape[-1] != 3:
raise NotImplementedError(
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
@@ -71,13 +112,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array)
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
PNG uses ``compress_level`` (09, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
"""
suffix = Path(fpath).suffix.lower()
if suffix == ".png":
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
"""
Saves a NumPy array or PIL Image to a file.
This function handles both NumPy arrays and PIL Image objects, converting
the former to a PIL Image before saving. It includes error handling for
the save operation.
the save operation. The output format is inferred from the *fpath*
extension: ``.png`` PNG with ``compress_level``, ``.tiff`` / ``.tif``
lossless raw depth maps (TIFF).
Args:
image (np.ndarray | PIL.Image.Image): The image data to save.
@@ -101,7 +157,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level)
img.save(fpath, **save_kwargs_for_path(Path(fpath), compress_level))
except Exception as e:
logger.error("Error writing image %s: %s", fpath, e)
+18 -18
View File
@@ -34,10 +34,12 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
STATS_PATH,
DatasetInfo,
serialize_dict,
)
@@ -114,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
return dataset
def write_info(info: dict, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH)
def write_info(info: DatasetInfo, local_dir: Path) -> None:
write_json(info.to_dict(), local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
def load_info(local_dir: Path) -> DatasetInfo:
"""Load dataset info metadata from its standard file path.
Also converts shape lists to tuples for consistency.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: The dataset information dictionary.
DatasetInfo: The typed dataset information object.
"""
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
raw = load_json(local_dir / INFO_PATH)
return DatasetInfo.from_dict(raw)
def write_stats(stats: dict, local_dir: Path) -> None:
@@ -188,6 +186,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -259,13 +265,11 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in {"language_persistent", "language_events"}:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None or isinstance(first_item, dict):
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -301,11 +305,7 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in [
"task",
"language_persistent",
"language_events",
]:
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
-150
View File
@@ -1,150 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls")
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"}
EXTENDED_STYLES = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")
-511
View File
@@ -1,511 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
)
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``
selector. ``events`` is accepted for resolver-signature uniformity but is
not consulted: only persistent styles are valid here.
"""
_validate_persistent_resolver("active_at", style)
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
matches = [row for row in matches if _timestamp(row) <= t]
return _select_latest(matches, style=style, role=role, tool_name=tool_name)
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``.
"""
column = column_for_style(style)
if column == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
if _timestamp(row) == t
]
return _select_one(
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key
)
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name)
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions before the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
style=style,
offset=-offset,
role=role,
tool_name=tool_name,
resolver_name="nth_prev",
)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions after the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
style=style,
offset=offset,
role=role,
tool_name=tool_name,
resolver_name="nth_next",
)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None:
"""Return ``task`` if set, otherwise look it up on ``dataset_ctx``."""
if task is not None:
return task
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
resolvers = {
"active_at": active_at,
"emitted_at": emitted_at,
"nth_prev": nth_prev,
"nth_next": nth_next,
}
if name not in resolvers:
raise ValueError(f"Unknown language resolver: {name!r}")
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return _PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
for idx, stream in enumerate(streams):
if stream is None:
raise ValueError(f"Rendered message {idx} has no stream.")
def _nth_relative(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
resolver_name: str,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(resolver_name, style)
if abs(offset) < 1:
raise ValueError(f"{resolver_name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name),
key=_persistent_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{resolver_name} requires a persistent style.")
if style in EVENT_ONLY_STYLES:
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
if style not in PERSISTENT_STYLES:
column_for_style(style)
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
]
def _select_latest(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
) -> LanguageRow | None:
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
if not rows:
return None
rows = sorted(rows, key=_persistent_sort_key)
latest_ts = _timestamp(rows[-1])
return _select_one(
[row for row in rows if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
sort_key=_persistent_sort_key,
)
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
sort_key: Any,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the selectors are ambiguous."""
if not rows:
return None
if len(rows) > 1 and role is None and tool_name is None:
raise ValueError(
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
)
return sorted(rows, key=sort_key)[0]
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
return (row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
value = row["timestamp"]
return float(value.item() if hasattr(value, "item") else value)
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized
+87 -44
View File
@@ -35,9 +35,11 @@ from .utils import (
is_valid_version,
)
from .video_utils import (
DepthEncoderConfig,
StreamingVideoEncoder,
get_safe_default_codec,
resolve_vcodec,
VideoEncoderConfig,
get_safe_default_video_backend,
seed_depth_feature_info,
)
logger = logging.getLogger(__name__)
@@ -58,10 +60,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: str | None = None,
return_uint8: bool = False,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
camera_encoder_config: VideoEncoderConfig | None = None,
depth_encoder_config: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -177,16 +180,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
camera_encoder_config (VideoEncoderConfig | None, optional): Video encoder settings for cameras
(codec, quality, etc.). Defaults to
:class:`~lerobot.datasets.video_utils.VideoEncoderConfig` defaults when ``None``.
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
codec decide.
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
streaming encoding. Defaults to 30 (~1s at 30fps).
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
libsvtav1 and 'threads' for h264/hevc.
Note:
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
@@ -202,10 +204,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec()
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
self._return_uint8 = return_uint8
self._batch_encoding_size = batch_encoding_size
self._vcodec = resolve_vcodec(vcodec)
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
self._camera_encoder_config = camera_encoder_config
self._depth_encoder_config = depth_encoder_config
self._encoder_threads = encoder_threads
if self._requested_root is not None:
@@ -248,16 +253,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
DeprecationWarning,
stacklevel=2,
)
seed_depth_feature_info(self.meta.features, self._depth_encoder_config)
streaming_enc = None
if streaming_encoding and len(self.meta.video_keys) > 0:
streaming_enc = self._build_streaming_encoder(
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
self.meta.fps,
self._camera_encoder_config,
self._encoder_threads,
encoder_queue_maxsize,
depth_encoder_config=self._depth_encoder_config,
depth_keys=self.meta.depth_keys,
)
self.writer = DatasetWriter(
meta=self.meta,
root=self.root,
vcodec=self._vcodec,
encoder_threads=encoder_threads,
camera_encoder_config=self._camera_encoder_config,
depth_encoder_config=self._depth_encoder_config,
encoder_threads=self._encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
initial_frames=self.meta.total_frames,
@@ -298,19 +310,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
@staticmethod
def _build_streaming_encoder(
fps: int,
vcodec: str,
encoder_queue_maxsize: int,
camera_encoder_config: VideoEncoderConfig,
encoder_threads: int | None,
encoder_queue_maxsize: int,
*,
depth_encoder_config: DepthEncoderConfig | None = None,
depth_keys: list[str] | None = None,
) -> StreamingVideoEncoder:
return StreamingVideoEncoder(
fps=fps,
vcodec=vcodec,
pix_fmt="yuv420p",
g=2,
crf=30,
preset=None,
queue_maxsize=encoder_queue_maxsize,
camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads,
queue_maxsize=encoder_queue_maxsize,
depth_encoder_config=depth_encoder_config,
depth_keys=depth_keys,
)
# ── Metadata properties ───────────────────────────────────────────
@@ -625,11 +638,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_threads: int = 0,
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
camera_encoder_config: VideoEncoderConfig | None = None,
depth_encoder_config: DepthEncoderConfig | None = None,
metadata_buffer_size: int = 10,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
video_files_size_in_mb: int | None = None,
data_files_size_in_mb: int | None = None,
) -> "LeRobotDataset":
"""Create a new LeRobotDataset from scratch for recording data.
@@ -654,20 +670,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend (used when reading back).
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos. ``1`` means encode immediately.
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
``'h264'``, ``'hevc'``, ``'auto'``.
camera_encoder_config: Video encoder settings for cameras; defaults
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
when ``None``.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
metadata_buffer_size: Number of episode metadata records to buffer
before flushing to parquet.
streaming_encoding: If ``True``, encode video frames in real-time
during capture instead of writing images first.
encoder_queue_maxsize: Max buffered frames per camera when using
streaming encoding.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns:
A new :class:`LeRobotDataset` in write mode.
"""
vcodec = resolve_vcodec(vcodec)
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
obj = cls.__new__(cls)
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
@@ -677,6 +696,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
root=root,
use_videos=use_videos,
metadata_buffer_size=metadata_buffer_size,
video_files_size_in_mb=video_files_size_in_mb,
data_files_size_in_mb=data_files_size_in_mb,
)
obj.repo_id = obj.meta.repo_id
obj._requested_root = obj.meta.root
@@ -686,23 +707,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._camera_encoder_config = camera_encoder_config
obj._depth_encoder_config = depth_encoder_config
obj._encoder_threads = encoder_threads
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
# Reader is lazily created on first access (write-only mode)
obj.reader = None
# Create writer
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
streaming_enc = cls._build_streaming_encoder(
fps,
camera_encoder_config,
encoder_threads,
encoder_queue_maxsize,
depth_encoder_config=depth_encoder_config,
depth_keys=obj.meta.depth_keys,
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
vcodec=vcodec,
camera_encoder_config=camera_encoder_config,
depth_encoder_config=depth_encoder_config,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
@@ -725,12 +755,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
video_backend: str | None = None,
batch_encoding_size: int = 1,
vcodec: str = "libsvtav1",
camera_encoder_config: VideoEncoderConfig | None = None,
depth_encoder_config: DepthEncoderConfig | None = None,
encoder_threads: int | None = None,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
streaming_encoding: bool = False,
encoder_queue_maxsize: int = 30,
encoder_threads: int | None = None,
) -> "LeRobotDataset":
"""Resume recording on an existing dataset.
@@ -753,13 +784,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_backend: Video decoding backend for reading back data.
batch_encoding_size: Number of episodes to accumulate before
batch-encoding videos.
vcodec: Video codec for encoding.
camera_encoder_config: Video encoder settings for cameras; defaults
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
when ``None``.
encoder_threads: Number of encoder threads (global). ``None``
lets the codec decide.
image_writer_processes: Subprocesses for async image writing.
image_writer_threads: Threads for async image writing.
streaming_encoding: If ``True``, encode video in real-time during
capture.
encoder_queue_maxsize: Max buffered frames per camera for streaming.
encoder_threads: Threads per encoder instance. ``None`` for auto.
Returns:
A :class:`LeRobotDataset` in write mode, ready to append episodes.
@@ -770,7 +804,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
"the shared cache. Please provide a local directory path."
)
vcodec = resolve_vcodec(vcodec)
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj._requested_root = Path(root)
@@ -779,11 +812,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.episodes = None
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
obj._return_uint8 = False
obj._batch_encoding_size = batch_encoding_size
obj._vcodec = vcodec
obj._encoder_threads = encoder_threads
if obj._requested_root is not None:
obj._requested_root.mkdir(exist_ok=True, parents=True)
@@ -792,21 +823,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.meta = LeRobotDatasetMetadata(
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
)
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
obj._camera_encoder_config = camera_encoder_config
obj._depth_encoder_config = depth_encoder_config
obj._encoder_threads = encoder_threads
obj.root = obj.meta.root
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
# Reader is lazily created on first access (write-only mode)
obj.reader = None
# Create writer for appending
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
obj.meta.fps,
camera_encoder_config,
encoder_threads,
encoder_queue_maxsize,
depth_encoder_config=depth_encoder_config,
depth_keys=obj.meta.depth_keys,
)
obj.writer = DatasetWriter(
meta=obj.meta,
root=obj.root,
vcodec=vcodec,
camera_encoder_config=camera_encoder_config,
depth_encoder_config=depth_encoder_config,
encoder_threads=encoder_threads,
batch_encoding_size=batch_encoding_size,
streaming_encoder=streaming_enc,
+2 -2
View File
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
return self._datasets[0].meta.info.fps
@property
def video(self) -> bool:
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
return len(self._datasets[0].meta.video_keys) > 0
@property
def features(self) -> datasets.Features:
+311
View File
@@ -0,0 +1,311 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
Checks degrade to a no-op when the target codec isn't available locally.
"""
from __future__ import annotations
import functools
import logging
from typing import TYPE_CHECKING, Any, Literal
import av
import numpy as np
import torch
from lerobot.datasets.depth_utils import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
quantize_depth,
dequantize_depth,
)
if TYPE_CHECKING:
from lerobot.datasets.video_utils import VideoEncoderConfig
logger = logging.getLogger(__name__)
# Pixel formats supported by the depth encode/decode helpers below. Both are
# 16-bit-word formats that carry 12 significant bits per sample, matching the
# ``DEPTH_QMAX = 4095`` quantization range.
DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
# planes with this value keeps the encoder from spending bits on chroma noise
# when only the Y plane carries information.
_NEUTRAL_CHROMA_12BIT: int = 2048
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def _write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
def encode_depth_frame_pyav(
depth: np.ndarray | torch.Tensor,
*,
pix_fmt: str = "yuv420p12le",
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> av.VideoFrame:
"""Quantize depth and pack it into a 12-bit PyAV video frame.
Args:
depth: Depth frame to encode (H, W). Unit handling follows
:func:`lerobot.datasets.depth_utils.quantize_depth`.
pix_fmt: Target pixel format. Must be one of :data:`DEPTH_PIX_FMTS`.
depth_min, depth_max, shift, use_log, input_unit: Forwarded to
:func:`quantize_depth`.
Returns:
An :class:`av.VideoFrame` in ``pix_fmt`` with quantized depth in the
luminance plane.
"""
if pix_fmt not in DEPTH_PIX_FMTS:
raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")
quantized_depth = quantize_depth(
depth,
depth_min=depth_min,
depth_max=depth_max,
shift=shift,
use_log=use_log,
input_unit=input_unit,
)
if quantized_depth.ndim != 2:
raise ValueError(f"depth must be a 2D frame; got shape {quantized_depth.shape}")
quantized_depth = np.ascontiguousarray(quantized_depth, dtype=np.uint16)
height, width = quantized_depth.shape
if pix_fmt == "gray12le":
frame = av.VideoFrame(width=width, height=height, format="gray12le")
_write_u16_plane(frame.planes[0], quantized_depth)
return frame
if height % 2 != 0 or width % 2 != 0:
raise ValueError("yuv420p12le requires even H and W")
frame = av.VideoFrame(width=width, height=height, format="yuv420p12le")
_write_u16_plane(frame.planes[0], quantized_depth)
neutral_chroma = np.full((height // 2, width // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
_write_u16_plane(frame.planes[1], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
_write_u16_plane(frame.planes[2], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
return frame
def decode_depth_frame_pyav(
frame: av.VideoFrame | list[av.VideoFrame],
*,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
return_quantized: bool = False,
output_unit: Literal["m", "mm"] = "m",
) -> np.ndarray:
"""Decode one or many depth video frames to quantized or metric depth.
Args:
frame: A single depth frame or a list of depth frames.
depth_min, depth_max, shift, use_log: Forwarded to
:func:`dequantize_depth`.
return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``.
output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).
Returns:
``(H, W)`` array for a single frame, or ``(N, H, W)`` for a list.
"""
frames = frame if isinstance(frame, list) else [frame]
quantized = np.stack([f.reformat(format="gray12le").to_ndarray() for f in frames]).astype(np.uint16, copy=False)
if return_quantized:
return quantized[0] if len(frames) == 1 else quantized
decoded = dequantize_depth(
quantized,
depth_min=depth_min,
depth_max=depth_max,
shift=shift,
use_log=use_log,
output_unit=output_unit,
)
return decoded[0] if len(frames) == 1 else decoded
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
try:
return av.codec.Codec(vcodec, "w")
except Exception:
return None
@functools.cache
def _get_codec_video_formats(vcodec: str) -> dict[str, av.option.Option]:
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
codec = get_codec(vcodec)
if codec is None:
return {}
return {opt.name: opt for opt in codec.descriptor.options}
@functools.cache
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
codec = get_codec(vcodec)
if codec is None:
return ()
return tuple(fmt.name for fmt in (codec.video_formats or []))
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
Each name is probed directly via :func:`get_codec`; input order is preserved.
"""
if isinstance(encoders, str):
encoders = [encoders]
available: list[str] = []
for name in encoders:
codec = get_codec(name)
if codec is not None and codec.type == "video":
available.append(name)
else:
logger.debug("encoder '%s' not available as video encoder", name)
return available
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
type_name = opt.type.name
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
if isinstance(value, bool):
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
elif isinstance(value, str):
try:
num_val = float(value)
except ValueError as e:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
) from e
elif isinstance(value, (float, int)):
num_val = value
else:
raise ValueError(
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
)
# Check integer type compatibility
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
raise ValueError(
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
)
# Check numeric range compatibility
lo, hi = float(opt.min), float(opt.max)
if lo < hi and not (lo <= num_val <= hi):
raise ValueError(
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
)
elif type_name == "STRING":
if isinstance(value, bool):
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
if isinstance(value, str):
str_val = value
elif isinstance(value, (int, float)):
str_val = str(value)
else:
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
# Check string choice compatibility
choices = [c.name for c in (opt.choices or [])]
if choices and str_val not in choices:
raise ValueError(
f"{label}={str_val!r} is not a supported choice for codec "
f"{vcodec!r}; valid choices: {choices}"
)
else:
return
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
formats = _get_codec_video_formats(vcodec)
if formats and pix_fmt not in formats:
raise ValueError(
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
f"supported pixel formats: {list(formats)}"
)
def _check_codec_options(vcodec: str, codec_options: dict[str, Any], config: VideoEncoderConfig) -> None:
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
supported_options = _get_codec_options_by_name(vcodec)
for key, value in codec_options.items():
# GOP size is not a codec-specific option, it has to be validated separately.
if key == "g":
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
continue
if key not in supported_options:
continue
opt = supported_options[key]
label = f"extra_options[{key!r}]" if key in config.extra_options else key
_check_option_value(vcodec, label, value, opt)
def check_video_encoder_config_pyav(config: VideoEncoderConfig) -> None:
"""Verify *config* is compatible with the bundled FFmpeg build.
Checks pixel format, abstract tuning-field compatibility, and each merged
encoder option from :meth:`~lerobot.datasets.video_utils.VideoEncoderConfig.get_codec_options`
against PyAV (including numeric ``extra_options`` present in that dict).
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
Raises:
ValueError: on the first incompatibility encountered.
"""
vcodec = config.vcodec
options = _get_codec_options_by_name(vcodec)
if not options:
logger.warning(
"Codec %r is not available in the bundled FFmpeg build; ",
vcodec,
)
return
_check_pixel_format(config.vcodec, config.pix_fmt)
_check_codec_options(config.vcodec, config.get_codec_options(), config)
+1 -1
View File
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
def _make_padding_camera_frame(self, camera_key: str):
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
def _get_video_frame_padding_mask(
self,
+130 -3
View File
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import dataclasses
import importlib.resources
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
import datasets
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
super().__init__(message)
logger = logging.getLogger(__name__)
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
@@ -83,16 +88,138 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
# Depth maps live alongside images on disk but use TIFF instead of PNG: PNG
# cannot natively round-trip float32, and several common loaders silently
# downcast 16-bit grayscale.
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
@dataclass
class DatasetInfo:
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
created by ``create_empty_dataset_info()``. Using a dataclass provides
explicit field definitions, IDE auto-completion, and validation at
construction time.
"""
codebase_version: str
fps: int
features: dict[str, dict]
# Episode / frame counters — start at zero for new datasets
total_episodes: int = 0
total_frames: int = 0
total_tasks: int = 0
# Storage settings
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
# File path templates
data_path: str = field(default=DEFAULT_DATA_PATH)
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
# Optional metadata
robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation
# returns lists, but the rest of the codebase expects tuples.
for ft in self.features.values():
if isinstance(ft.get("shape"), list):
ft["shape"] = tuple(ft["shape"])
if self.fps <= 0:
raise ValueError(f"fps must be positive, got {self.fps}")
if self.chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
if self.data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
if self.video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
def to_dict(self) -> dict:
"""Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them.
"""
d = dataclasses.asdict(self)
for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"])
return d
@classmethod
def from_dict(cls, data: dict) -> "DatasetInfo":
"""Construct from a raw dict (e.g. loaded directly from JSON).
Unknown keys are ignored for forward compatibility with datasets that
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
logged when such fields are present.
"""
known = {f.name for f in dataclasses.fields(cls)}
unknown = sorted(k for k in data if k not in known)
if unknown:
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
return cls(**{k: v for k, v in data.items() if k in known})
# ---------------------------------------------------------------------------
# Temporary dict-style compatibility layer
# Allows existing ``info["key"]`` call-sites to keep working without changes.
# Once all callers have been migrated to attribute access, remove these.
# ---------------------------------------------------------------------------
def __getitem__(self, key: str):
import warnings
warnings.warn(
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
f"Use attribute access info.{key} instead.",
DeprecationWarning,
stacklevel=2,
)
try:
return getattr(self, key)
except AttributeError as err:
raise KeyError(key) from err
def __setitem__(self, key: str, value) -> None:
import warnings
warnings.warn(
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
f"Use attribute assignment info.{key} = ... instead.",
DeprecationWarning,
stacklevel=2,
)
if not hasattr(self, key):
raise KeyError(f"DatasetInfo has no field '{key}'")
setattr(self, key, value)
def __contains__(self, key: str) -> bool:
"""Check if a field exists (dict-like interface)."""
return hasattr(self, key)
def get(self, key: str, default=None):
"""Get attribute value with default fallback (dict-like interface)."""
try:
return getattr(self, key)
except AttributeError:
return default
def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
@@ -293,7 +420,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
def create_lerobot_dataset_card(
tags: list | None = None,
dataset_info: dict | None = None,
dataset_info: DatasetInfo | None = None,
**kwargs,
) -> DatasetCard:
"""Create a `DatasetCard` for a LeRobot dataset.
@@ -304,7 +431,7 @@ def create_lerobot_dataset_card(
Args:
tags (list | None): A list of tags to add to the dataset card.
dataset_info (dict | None): The dataset's info dictionary, which will
dataset_info (DatasetInfo | None): The dataset's info object, which will
be displayed on the card.
**kwargs: Additional keyword arguments to populate the card template.
@@ -317,7 +444,7 @@ def create_lerobot_dataset_card(
card_tags += tags
if dataset_info:
dataset_structure = "[meta/info.json](meta/info.json):\n"
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
kwargs = {**kwargs, "dataset_structure": dataset_structure}
card_data = DatasetCardData(
license=kwargs.get("license"),
+517 -145
View File
@@ -17,12 +17,13 @@ import contextlib
import glob
import importlib
import logging
import math
import queue
import shutil
import tempfile
import threading
import warnings
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from fractions import Fraction
from pathlib import Path
from threading import Lock
@@ -37,7 +38,23 @@ import torchvision
from datasets.features.features import register_feature
from PIL import Image
from lerobot.utils.import_utils import get_safe_default_codec
from lerobot.datasets.pyav_utils import (
check_video_encoder_config_pyav,
depth_to_video_frame,
detect_available_encoders_pyav,
decode_depth_frame,
encode_depth_frame_pyav,
decode_depth_frame_pyav,
)
from lerobot.datasets.depth_utils import (
quantize_depth,
dequantize_depth,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
)
from lerobot.utils.import_utils import get_safe_default_video_backend
logger = logging.getLogger(__name__)
@@ -52,70 +69,226 @@ HW_ENCODERS = [
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "ffv1", "auto"} | set(HW_ENCODERS)
LIBSVTAV1_DEFAULT_PRESET: int = 12
def _get_codec_options(
vcodec: str,
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
) -> dict:
"""Build codec-specific options dict for video encoding."""
options = {}
@dataclass
class VideoEncoderConfig:
"""Video encoder configuration.
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
options["g"] = str(g)
Attributes:
vcodec: FFmpeg encoder name. ``"auto"`` is resolved during
construction (HW encoder if available, else ``libsvtav1``).
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
g: GOP size (keyframe interval).
crf: Quality level mapped to the native quality parameter of the
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
preset: Speed/quality preset. Accepted type is per-codec.
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
set ``tune=fastdecode``. Ignored for other codecs.
video_backend: Python library driving FFmpeg for encoding. Only ``"pyav"``
is currently supported.
extra_options: Free-form dictionary of additional FFmpeg options
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
"""
# Quality control (codec-specific parameter names)
if crf is not None:
if vcodec in ("h264", "hevc", "libsvtav1"):
options["crf"] = str(crf)
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
quality = max(1, min(100, int(100 - crf * 2)))
options["q:v"] = str(quality)
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
options["rc"] = "constqp"
options["qp"] = str(crf)
elif vcodec in ("h264_vaapi",):
options["qp"] = str(crf)
elif vcodec in ("h264_qsv",):
options["global_quality"] = str(crf)
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
g: int | None = 2
crf: int | None = 30
preset: int | str | None = None
fast_decode: int = 0
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
# two backends (encoding and decoding).
video_backend: str = "pyav"
extra_options: dict[str, Any] = field(default_factory=dict)
# Preset (only for libsvtav1)
if vcodec == "libsvtav1":
options["preset"] = str(preset) if preset is not None else "12"
# Class-level marker persisted to ``info.json`` (via ``asdict``) so the
# reader can tell depth datasets from RGB ones without a separate dispatch
# path. ``init=False`` keeps it out of CLI/constructor surface; subclasses
# flip the default (see :class:`DepthEncoderConfig`).
is_depth_map: bool = field(default=False, init=False)
return options
def __post_init__(self) -> None:
self.resolve_vcodec()
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
if self.preset is None and self.vcodec == "libsvtav1":
self.preset = LIBSVTAV1_DEFAULT_PRESET
self.validate()
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
"""Detect available encoders based on the video backend."""
if self.video_backend == "pyav":
return detect_available_encoders_pyav(encoders)
else:
return []
def validate(self) -> None:
"""Validate the video encoder config."""
if self.video_backend == "pyav":
check_video_encoder_config_pyav(self)
def resolve_vcodec(self) -> None:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.
Any explicitly-requested codec that isn't in the local FFmpeg build is
also silently rewritten to ``libsvtav1`` so encoding never hard-fails on
a host missing the requested encoder.
"""
# Backward compatibility: older datasets persist ``vcodec="av1"`` in
# ``info.json``. Rewrite to the canonical encoder name *before* the
# validation check below so loading those datasets keeps working.
if self.vcodec == "av1":
self.vcodec = "libsvtav1"
if self.vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if self.vcodec == "auto":
available = self.detect_available_encoders(HW_ENCODERS)
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
self.vcodec = encoder
return
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
self.vcodec = "libsvtav1"
if self.detect_available_encoders(self.vcodec):
logger.info(f"Using video codec: {self.vcodec}")
self.vcodec = self.vcodec
return
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
def get_codec_options(
self, encoder_threads: int | None = None, as_strings: bool = False
) -> dict[str, str]:
"""Translate the tuning fields to codec-specific FFmpeg options.
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
Args:
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
For h264/hevc, this is mapped to ``threads``.
Hardware encoders ignore this parameter.
as_strings: If ``True``, casts values to strings.
"""
opts: dict[str, Any] = {}
def set_if(key: str, value: Any) -> None:
if value is not None:
opts[key] = value if not as_strings else str(value)
# GOP size is not a codec-specific option, so it is always set.
set_if("g", self.g)
if self.vcodec == "libsvtav1":
set_if("crf", self.crf)
set_if("preset", self.preset)
svtav1_parts: list[str] = []
if self.fast_decode is not None:
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
if encoder_threads is not None:
svtav1_parts.append(f"lp={encoder_threads}")
if svtav1_parts:
opts["svtav1-params"] = ":".join(svtav1_parts)
elif self.vcodec in ("h264", "hevc"):
set_if("crf", self.crf)
set_if("preset", self.preset)
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("threads", encoder_threads)
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
opts["rc"] = "constqp"
set_if("qp", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "h264_vaapi":
set_if("qp", self.crf)
elif self.vcodec == "h264_qsv":
set_if("global_quality", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "ffv1":
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
# are not meaningful.
set_if("threads", encoder_threads)
else:
set_if("crf", self.crf)
set_if("preset", self.preset)
# Extra options are merged last but never override structured fields (values are kept as given).
for k, v in self.extra_options.items():
if k not in opts:
set_if(k, v)
return opts
def detect_available_hw_encoders() -> list[str]:
"""Probe PyAV/FFmpeg for available hardware video encoders."""
available = []
for codec_name in HW_ENCODERS:
try:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
return available
@dataclass
class DepthEncoderConfig(VideoEncoderConfig):
"""Encoder configuration for depth-map streams.
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
preset, ``extra_options``) and adds the four parameters of the depth
quantization pipeline (:func:`quantize_depth`). Inheritance rather
than composition keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
works identically to its RGB counterpart.
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
to ``"yuv420p12le"``, the most widely available 12-bit pixel format.
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
``ffv1`` doesn't expose those tuning knobs).
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
so it's hidden from CLI and constructor args) and is what the reader
side keys on to tell depth datasets from RGB ones.
Attributes:
depth_min: Minimum depth in physical units (e.g. metres) represented
by quantum ``0``.
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
shift: Pre-log offset for numerical stability near zero.
use_log: ``True`` for logarithmic quantization (default; matches
sensor error profile), ``False`` for linear.
"""
vcodec: str = "hevc"
pix_fmt: str = "yuv420p12le"
depth_min: float = DEFAULT_DEPTH_MIN
depth_max: float = DEFAULT_DEPTH_MAX
shift: float = DEFAULT_DEPTH_SHIFT
use_log: bool = DEFAULT_DEPTH_USE_LOG
# Class invariant — kept out of ``__init__`` (and CLI) but persisted
# via ``asdict`` into ``info.json`` for the reader to detect depth.
is_depth_map: bool = field(default=True, init=False)
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor:
"""Apply :func:`quantize_depth` bound to this config's parameters."""
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log)
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor:
"""Apply :func:`dequantize_depth` bound to this config's parameters."""
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log)
def resolve_vcodec(vcodec: str) -> str:
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logger.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
return encoder
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
def depth_encoder_defaults() -> DepthEncoderConfig:
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
return DepthEncoderConfig()
def camera_encoder_defaults() -> VideoEncoderConfig:
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
return VideoEncoderConfig()
def decode_video_frames(
@@ -142,7 +315,7 @@ def decode_video_frames(
Currently supports torchcodec on cpu and pyav.
"""
if backend is None:
backend = get_safe_default_codec()
backend = get_safe_default_video_backend()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend in ["pyav", "video_reader"]:
@@ -396,22 +569,136 @@ def decode_video_frames_torchcodec(
return closest_frames
def decode_depth_frames(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
*,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
return_quantized: bool = False,
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
"""Decode depth-map frames at the requested timestamps using PyAV.
Mirrors the timestamp-tolerance / closest-frame contract of
:func:`decode_video_frames` but operates entirely through PyAV (the
``torchvision`` and ``torchcodec`` backends don't currently round-trip
12-bit pixel formats reliably).
Each decoded frame is reformatted to ``gray12le`` so the same path
handles ``yuv420p12le`` (HEVC default) and ``gray12le`` (ffv1 archive)
sources transparently.
Args:
video_path: Path to a depth video produced with a
:class:`DepthEncoderConfig`.
timestamps: Frame timestamps to retrieve, in seconds.
tolerance_s: Maximum allowed deviation between the queried and the
actually-decoded timestamps.
depth_min, depth_max, shift, use_log: Parameters used at quantization
time. Should match :func:`info_to_depth_kwargs` extracted from
``info.json`` for the source dataset.
return_quantized: If ``True``, skip the dequantization step and
return raw 12-bit ``uint16`` quanta.
log_loaded_timestamps: Debug logging.
Returns:
``torch.Tensor`` of shape ``(N, H, W)``:
* ``dtype=torch.float32`` (metric depth, default)
* ``dtype=torch.uint16`` when ``return_quantized=True``.
Raises:
FrameTimestampError: If a query timestamp can't be matched within
*tolerance_s*, or if no frames are decoded.
"""
video_path_str = str(video_path)
first_ts = min(timestamps)
last_ts = max(timestamps)
loaded_frames: list[np.ndarray] = []
loaded_ts: list[float] = []
av.logging.set_level(av.logging.WARNING)
with av.open(video_path_str, "r") as container:
try:
stream = container.streams.video[0]
except IndexError as e:
raise FrameTimestampError(f"No video stream in {video_path_str}") from e
# Seek to the keyframe at-or-before first_ts (PyAV doesn't do
# accurate seek, so we still iterate forward to the requested range).
seek_pts = int(first_ts / stream.time_base)
container.seek(seek_pts, stream=stream, any_frame=False, backward=True)
for frame in container.decode(stream):
if frame.pts is None:
continue
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"depth frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(
decode_depth_frame(
frame,
depth_min=depth_min,
depth_max=depth_max,
shift=shift,
use_log=use_log,
return_quantized=True,
)
)
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
av.logging.restore_default_callback()
if not loaded_frames:
raise FrameTimestampError(
f"No depth frames decoded from {video_path_str} for timestamps {timestamps}"
)
query_ts = torch.tensor(timestamps)
loaded_ts_t = torch.tensor(loaded_ts)
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps violate the tolerance "
f"({min_[~is_within_tol]} > {tolerance_s=})."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts_t}"
f"\nvideo: {video_path_str}"
)
closest = np.stack([loaded_frames[i] for i in argmin_]) # (N, H, W) uint16
quantized = torch.from_numpy(closest)
if return_quantized:
return quantized
return dequantize_depth(quantized, depth_min, depth_max, shift, use_log)
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
fast_decode: int = 0,
camera_encoder_config: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
*,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
preset: int | None = None,
encoder_threads: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
vcodec = resolve_vcodec(vcodec)
if camera_encoder_config is None:
camera_encoder_config = VideoEncoderConfig()
vcodec = camera_encoder_config.vcodec
pix_fmt = camera_encoder_config.pix_fmt
video_path = Path(video_path)
imgs_dir = Path(imgs_dir)
@@ -422,42 +709,18 @@ def encode_video_frames(
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logger.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
# Get input frames
template = "frame-" + ("[0-9]" * 6) + ".png"
input_list = sorted(
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
)
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
# Define video codec options
video_options = _get_codec_options(vcodec, g, crf, preset)
if fast_decode:
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if encoder_threads is not None:
if vcodec == "libsvtav1":
lp_param = f"lp={encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(encoder_threads)
video_options = camera_encoder_config.get_codec_options(encoder_threads, as_strings=True)
# Set logging level
if log_level is not None:
@@ -494,7 +757,10 @@ def encode_video_frames(
def concatenate_video_files(
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
input_video_paths: list[Path | str],
output_video_path: Path,
overwrite: bool = True,
compatibility_check: bool = False,
):
"""
Concatenate multiple video files into a single video file using pyav.
@@ -507,6 +773,7 @@ def concatenate_video_files(
input_video_paths: Ordered list of input video file paths to concatenate.
output_video_path: Path to the output video file.
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
compatibility_check: Whether to check if the input videos are compatible. Default is False.
Note:
- Creates a temporary directory for intermediate files that is cleaned up after use.
@@ -525,6 +792,22 @@ def concatenate_video_files(
if len(input_video_paths) == 0:
raise FileNotFoundError("No input video paths provided.")
# This check may be skipped at recording time as videos are encoded with the same encoder config.
if compatibility_check:
reference_video_info = get_video_info(input_video_paths[0])
for input_path in input_video_paths[1:]:
video_info = get_video_info(input_path)
if (
video_info["video.height"] != reference_video_info["video.height"]
or video_info["video.width"] != reference_video_info["video.width"]
or video_info["video.fps"] != reference_video_info["video.fps"]
or video_info["video.codec"] != reference_video_info["video.codec"]
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
):
raise ValueError(
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
)
# Create a temporary .ffconcat file to list the input video paths
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
@@ -591,33 +874,31 @@ class _CameraEncoderThread(threading.Thread):
fps: int,
vcodec: str,
pix_fmt: str,
g: int | None,
crf: int | None,
preset: int | None,
codec_options: dict[str, str],
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
encoder_threads: int | None = None,
depth_encoder_config: "DepthEncoderConfig | None" = None,
):
super().__init__(daemon=True)
self.video_path = video_path
self.fps = fps
self.vcodec = vcodec
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self.codec_options = codec_options
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.encoder_threads = encoder_threads
self.depth_encoder_config = depth_encoder_config
def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
container = None
output_stream = None
stats_tracker = RunningQuantileStats()
is_depth = self.depth_encoder_config is not None
stats_tracker = RunningQuantileStats() if not is_depth else None
frame_count = 0
try:
@@ -635,51 +916,45 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8:
if frame_data.dtype != np.uint8 and not is_depth:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
if container is None:
height, width = frame_data.shape[:2]
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
if self.encoder_threads is not None:
if self.vcodec == "libsvtav1":
lp_param = f"lp={self.encoder_threads}"
if "svtav1-params" in video_options:
video_options["svtav1-params"] += f":{lp_param}"
else:
video_options["svtav1-params"] = lp_param
else:
video_options["threads"] = str(self.encoder_threads)
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
container = av.open(str(self.video_path), "w")
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
output_stream.pix_fmt = self.pix_fmt
output_stream.width = width
output_stream.height = height
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
if is_depth:
video_frame = encode_depth_frame_pyav(frame_data, pix_fmt=self.pix_fmt, depth_min=self.depth_encoder_config.depth_min, depth_max=self.depth_encoder_config.depth_max, shift=self.depth_encoder_config.shift, use_log=self.depth_encoder_config.use_log)
else:
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
if packet:
container.mux(packet)
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
img_downsampled = auto_downsample_height_width(img_chw)
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
stats_tracker.update(img_for_stats)
if not is_depth:
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
img_downsampled = auto_downsample_height_width(img_chw)
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
stats_tracker.update(img_for_stats)
frame_count += 1
@@ -694,8 +969,10 @@ class _CameraEncoderThread(threading.Thread):
av.logging.restore_default_callback()
# Get stats and put on result queue
if frame_count >= 2:
# Get stats and put on result queue (depth streams skip stats)
if is_depth:
self.result_queue.put(("ok", None))
elif frame_count >= 2:
stats = stats_tracker.get_statistics()
self.result_queue.put(("ok", stats))
else:
@@ -724,22 +1001,40 @@ class StreamingVideoEncoder:
def __init__(
self,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,
preset: int | None = None,
queue_maxsize: int = 30,
camera_encoder_config: VideoEncoderConfig | None = None,
encoder_threads: int | None = None,
*,
queue_maxsize: int = 30,
depth_encoder_config: "DepthEncoderConfig | None" = None,
depth_keys: list[str] | None = None,
):
"""
Args:
fps: Frames per second for the output videos.
camera_encoder_config: Video encoder settings applied to all cameras.
When ``None``, :class:`VideoEncoderConfig` defaults are used.
encoder_threads: Number of encoder threads (global setting).
``None`` lets the codec decide.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
depth_encoder_config: Optional depth encoder configuration applied
to all depth video keys listed in ``depth_keys``.
depth_keys: Video keys (matching the dataset feature names) that
must be encoded as quantized depth maps using
``depth_encoder_config``. Required when ``depth_encoder_config``
is provided.
"""
self.fps = fps
self.vcodec = resolve_vcodec(vcodec)
self.pix_fmt = pix_fmt
self.g = g
self.crf = crf
self.preset = preset
self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig()
self._encoder_threads = encoder_threads
self.queue_maxsize = queue_maxsize
self.encoder_threads = encoder_threads
self._depth_encoder_config = depth_encoder_config
self._depth_keys: set[str] = set(depth_keys or [])
if self._depth_keys and self._depth_encoder_config is None:
raise ValueError(
"StreamingVideoEncoder received depth_keys without a depth_encoder_config; "
"either pass a DepthEncoderConfig or remove depth_keys."
)
self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {}
@@ -770,18 +1065,28 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
is_depth_key = video_key in self._depth_keys
encoder_cfg: VideoEncoderConfig
depth_cfg = None
if is_depth_key:
assert self._depth_encoder_config is not None # guaranteed by __init__
encoder_cfg = self._depth_encoder_config
depth_cfg = self._depth_encoder_config
else:
encoder_cfg = self._camera_encoder_config
vcodec = encoder_cfg.vcodec
codec_options = encoder_cfg.get_codec_options(self._encoder_threads)
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=self.vcodec,
pix_fmt=self.pix_fmt,
g=self.g,
crf=self.crf,
preset=self.preset,
vcodec=vcodec,
pix_fmt=encoder_cfg.pix_fmt,
codec_options=codec_options,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
encoder_threads=self.encoder_threads,
depth_encoder_config=depth_cfg,
)
encoder_thread.start()
@@ -986,8 +1291,18 @@ def get_audio_info(video_path: Path | str) -> dict:
return audio_info
def get_video_info(video_path: Path | str) -> dict:
# Set logging level
def get_video_info(
video_path: Path | str,
video_encoder_config: "VideoEncoderConfig | None" = None,
) -> dict:
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
Args:
video_path: Path to the encoded video file to probe.
video_encoder_config: If provided, record the exact encoder settings used to encode this
video. Stream-derived values take precedence encoder fields are only written for keys
not already populated from the video file itself.
"""
logging.getLogger("libav").setLevel(av.logging.WARNING)
# Getting video stream information
@@ -1004,7 +1319,6 @@ def get_video_info(video_path: Path | str) -> dict:
video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
@@ -1018,9 +1332,67 @@ def get_video_info(video_path: Path | str) -> dict:
# Adding audio stream information
video_info.update(**get_audio_info(video_path))
# Add additional encoder configuration if provided (no override of stream-derived values)
# Depth related fields flow naturally through this path.
if video_encoder_config is not None:
for field_name, field_value in asdict(video_encoder_config).items():
video_info.setdefault(f"video.{field_name}", field_value)
# Fallback case where no encoder config is provided or the video is not a depth map.
video_info.setdefault("video.is_depth_map", False)
return video_info
# ─── Depth metadata helpers (reader side) ────────────────────────────
_DEPTH_INFO_KEYS: tuple[str, ...] = (
"video.depth_min",
"video.depth_max",
"video.shift",
"video.use_log",
)
def seed_depth_feature_info(
features: dict[str, dict],
depth_encoder_config: "DepthEncoderConfig | None",
) -> None:
"""Pre-populate per-feature ``video.<field>`` entries from *depth_encoder_config*.
``update_video_info`` only runs after the first episode video is encoded,
so without this seeding step ``features[key]["info"]`` carries no
quantization range until then. Consumers that read the dataset feature
spec mid-recording (e.g. the rerun visualizer pinning the depth colormap
to ``video.depth_min`` / ``video.depth_max``) would otherwise see no
range during episode 1 and re-normalize per frame.
Stream-derived values written later by :func:`get_video_info` /
``update_video_info`` win over these seeds (the merge is
``{**existing, **stream_info}``), so callers can safely re-run this on
a partially-populated info dict.
No-op when ``depth_encoder_config`` is ``None`` or no feature is flagged
as a depth map.
"""
if depth_encoder_config is None:
return
encoder_fields = {
f"video.{name}": value for name, value in asdict(depth_encoder_config).items()
}
for ft in features.values():
if ft.get("dtype") != "video":
continue
info = ft.get("info") or {}
if not info.get("video.is_depth_map", False):
continue
# Only fill fields not already set, so explicit user-provided info is preserved.
for k, v in encoder_fields.items():
info.setdefault(k, v)
ft["info"] = info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
+2 -5
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
@@ -21,10 +23,7 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
@@ -45,9 +44,7 @@ __all__ = [
"PI0Config",
"PI0FastConfig",
"PI05Config",
"RewardClassifierConfig",
"SACConfig",
"SARMConfig",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+4 -3
View File
@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
num_valid = valid_mask.sum() * abs_err.shape[-1]
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
f"{self.config.do_mask_loss_for_padding=}."
)
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
mask = in_episode_bound.unsqueeze(-1)
num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean()
+2 -30
View File
@@ -52,8 +52,6 @@ from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
@@ -89,7 +87,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -132,18 +130,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "reward_classifier":
from .sac.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla":
from .smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from .sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from .groot.modeling_groot import GrootPolicy
@@ -173,7 +163,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -200,8 +190,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
@@ -378,14 +366,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, RewardClassifierConfig):
from .sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SmolVLAConfig):
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
@@ -394,14 +374,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from .sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from .groot.processor_groot import make_groot_pre_post_processors
+5 -9
View File
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
@@ -174,17 +173,14 @@ N_COLOR_CHANNELS = 3
# config
@dataclass
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
backbone_cfg: dict
action_head_cfg: dict
action_horizon: int
action_dim: int
compute_dtype: str = "float32"
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
mask = ~batch["action_is_pad"].unsqueeze(-1)
num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean()
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
mask = ~batch["action_is_pad"].unsqueeze(-1)
num_valid = mask.sum() * loss.shape[-1]
return (loss * mask).sum() / num_valid.clamp_min(1)
return loss.mean()

Some files were not shown because too many files have changed in this diff Show More