Compare commits

..

72 Commits

Author SHA1 Message Date
Pepijn 026e4c937d fix(viz): use PyAV for AV1 video decoding, AutoImageProcessor for SigLIP
- Replace cv2.VideoCapture with PyAV (av library) which handles AV1
  codec properly. Decode each video once and index by frame number.
- Use AutoImageProcessor instead of AutoProcessor to avoid loading
  the SigLIP tokenizer (which requires sentencepiece).

Made-with: Cursor
2026-03-23 23:02:50 -07:00
Pepijn efe8c09fca chore(viz): bump SigLIP batch size to 512 for H100
Made-with: Cursor
2026-03-23 20:31:02 -07:00
Pepijn 58eecad8a4 feat(viz): add image-based consistency analysis with SigLIP
Run two parallel KNN analyses per dataset:
  1. State-based: KNN in joint-state space
  2. Image-based: KNN in SigLIP embedding space (google/siglip-base-patch16-224)

Both measure action chunk variance among cross-episode neighbors.
Comparing them reveals whether visual and proprioceptive similarity
agree on where data is inconsistent.

Output is a 4-row figure: state histogram, image histogram,
overlaid per-episode curves, and spatial heatmap colored by
image-based variance.

Made-with: Cursor
2026-03-23 20:29:36 -07:00
Pepijn c7fd1f47d1 fix(viz): use relative output path instead of hardcoded absolute path
Made-with: Cursor
2026-03-23 20:17:59 -07:00
Pepijn 6370949e5c feat(viz): add dataset quality visualization tools
Add three new analysis scripts for dataset quality insight:
- create_frame_grid.py: random frame grid JPG for visual inspection
- workspace_density.py: 3D TCP trajectory clustering with K-means
- action_consistency.py: KNN-based action-state consistency analysis
  with action chunk support (default chunk=30) matching policy learning

Also update create_progress_videos.py with configurable camera selection.

Made-with: Cursor
2026-03-23 20:15:15 -07:00
Pepijn 46b97da168 Merge branch 'feat/analysis_dataset' of https://github.com/huggingface/lerobot into feat/analysis_dataset 2026-03-18 11:44:17 -07:00
Pepijn e69be57a66 precommit 2026-03-18 10:19:29 -07:00
Pepijn c3a6ddb668 Merge branch 'main' into feat/analysis_dataset 2026-03-18 10:16:24 -07:00
Pepijn dad661012d nit 2026-03-18 10:15:16 -07:00
Pepijn 219c08ccb8 add example for creating progress video for sarm 2026-03-18 10:13:46 -07:00
Altman e64fa667c3 fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors (#3128)
* fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors

When VQ discretization phase completes, the code was overwriting
register_buffer('discretized') and register_buffer('freeze_codebook')
with torch.tensor(True), which is created on CPU. DDP then fails in
_sync_buffers() with: RuntimeError: No backend type associated with
device type cpu. Fix by updating the buffers in-place with .fill_(True)
so device and registration are preserved.

Made-with: Cursor

* test(vqbet): add regression test for in-place buffer update during discretization

Verifies that discretize() updates the 'discretized' and 'freeze_codebook'
registered buffers in-place (via fill_()) rather than replacing them with new
CPU tensors. The test checks data_ptr() identity and that the tensors remain
registered buffers after the call. This prevents regressions of the DDP fix.

Made-with: Cursor

* test(vqbet): add GPU regression test to verify buffers stay on CUDA after discretize()

Directly catches the original DDP failure mode: when buffers are replaced with
torch.tensor(True) they land on CPU, causing NCCL to raise 'No backend type
associated with device type cpu' in _sync_buffers(). The GPU test places the
model on cuda:0 and asserts both buffers remain on CUDA after discretization.

Made-with: Cursor

* test(vqbet): simplify to single device-check test in test_policies.py

Per reviewer feedback: remove the separate test file and replace the two
CPU/GPU tests (with data_ptr checks) with a single focused test in
tests/policies/test_policies.py that only asserts the registered buffers
remain on the model device after discretize(). Uses DEVICE from tests/utils.py
so it runs on whatever device the CI/user selects (cpu, cuda, mps).

Made-with: Cursor

* style: fix import order in test_policies.py to pass ruff/pre-commit checks

Made-with: Cursor

---------

Co-authored-by: Zhan DiJia <2476100824@example.com>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
2026-03-18 13:24:07 +01:00
Khalil Meftah d9ec3a6fa2 Fix/earth rover dataset features (#3088)
* docs(earthrover): update EarthRover Mini Plus dataset features and descriptions

* refactor(teleop): rename rover action keys to linear_velocity/angular_velocity

* fix(earthrover): align observation and action features with frodobots/berkeley-frodobots-lerobot-7k

* chore: address PR review comments

* ci: retrigger checks
2026-03-17 18:33:53 +01:00
Steven Palma d90e4bcfd3 refactor(dataset): modular files (#3171)
* refactor(dataset): modular files

* refactor(dataset): update imports across the codebase
2026-03-15 23:58:09 -07:00
Steven Palma 9d3b62aa61 chore(dataset): basic house-keeping (#3170) 2026-03-15 22:12:09 -07:00
Steven Palma 7c2ec31793 refactor(datasets): module cleanup (#3169) 2026-03-15 20:42:15 -07:00
Steven Palma a07b1d76f1 chore(dependecies): untangle dependecies across internal modules (#3149) 2026-03-15 20:26:06 -07:00
Caroline Pascal 2ec1dafcc2 fix(lerobot-train): fixing lerobot-train --help by removing % in the docstrings (draccus does not support the character) (#3161) 2026-03-14 10:49:53 -07:00
Pepijn 06385902df Add create reward visualization and multimodal analysis tool 2026-03-13 09:28:26 -07:00
Caroline Pascal 2d6259156b fix(links): replacing relative links with absolute links in the contribution guide (#3141)
* fix(links): replacing relative links with absolute links in the contribution guide

* fix(links): replacing relative link in the README
2026-03-12 20:46:05 -07:00
Bruno Machado 0db5f66dda Add option to disable tags on WandB (#1339)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-11 16:54:08 -07:00
Steven Palma efee611403 fix(policies): crop losses based on the action dof (#3133)
Co-authored-by: Chenning Yu <rainorangelemon@gmail.com>
2026-03-11 16:51:31 -07:00
Heuzef c15b75e3da Update Dockerfile.user (#1633)
Instruction for USB ports access with container

Signed-off-by: Heuzef <contact@heuzef.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-11 16:45:43 -07:00
H.Yamada f311ca3dce Fix action padding key at SmolVLA (#1717)
Issue https://github.com/huggingface/lerobot/issues/1707

Action padding mask is set at LeRobotDataset as f"{key}_is_pad".

Wrong key doesn't raise any errors, however, padding mask is ignored,
resulting wrong attention at around the edges of an episode
when multi step actions is enabled (aka. action horizon is greater
than 1).

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-03-11 12:12:21 -07:00
Silvio Traversaro 19c6adef85 chore(dependencies): Increase opencv-python-headless upper bound (#3120)
Signed-off-by: Silvio Traversaro <silvio@traversaro.it>
2026-03-09 23:27:18 +01:00
Johnson Sun 96b7f3dae0 Parse HF_USER with NO_COLOR to avoid incorrectly capturing bash ANSI codes (#3119) 2026-03-09 18:47:58 +01:00
Martino Russi 885ef91892 fix(unitree_g1): correct SDK detection and update installation docs (#3115)
* update docs

* update toml / docs

* update docs

* fix joystick

* Update pyproject.toml

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* update toml and docs

* update docs

* clarify robot

* update docs

* update docs

* update pinocchio deps

* final touches

* Update docs/source/unitree_g1.mdx

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* move envhub dependencies to docs

* point to unitree_sdk docs

* upper bound on onnx

* chore(docs): small details unitree docs

* chore(deps): add version pin and unitree_sdk hint

---------

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-09 18:47:12 +01:00
Steven Palma b0efa73520 chore(dependencies): Bump lerobot to 0.5.1 (#3118) 2026-03-09 12:43:32 +01:00
Steven Palma 00b662de02 chore(dependencies): Bump lerobot to 0.5.0 (#3117) 2026-03-09 11:34:52 +01:00
Steven Palma 5c51a74484 chore(deps): update requirements file (#3114) 2026-03-09 11:18:05 +01:00
Steven Palma db8547e35d test(cameras): skip flaky async_read test (#3106) 2026-03-08 14:02:33 +01:00
Steven Palma c17d949531 chore(readme): update citation with ICLR26 paper (#3107)
* peer reviewed citation 🎉

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* add iclr year

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* fix quentin's spelling name

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* docs(readme): update citation

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
2026-03-08 14:01:43 +01:00
Steven Palma 1e131f93f8 chore(docs): add uv installation instructions (#3105)
* chore(docs): add uv installation instructions

* fix(docs): format tabs

* chore(docs): small details

* chore(docs): last details uv installation instructions

* chore(docs): last detail

---

Co-authored-by: sahilmaniyar888 <156301258+sahilmaniyar888@users.noreply.github.com>
2026-03-08 13:00:06 +01:00
Ignat Georgiev 2fb5c7add0 feat(train): add cudnn_deterministic option for reproducible training (#3102)
Add a `cudnn_deterministic` flag to `TrainPipelineConfig` (default: False)
that sets `torch.backends.cudnn.deterministic = True` and disables benchmark
mode, eliminating CUDA floating-point non-determinism at the cost of ~10-20%
training speed. When False (default) the existing benchmark=True behaviour
is preserved.
2026-03-08 12:29:33 +01:00
Martino Russi 4f2ef024d8 feat(robots): Unitree G1 WBC implementation (#2876)
* move locomotion from examples to robot, move controller to teleoperator class

* modify teleoperate to send back actions to robot

* whole body controller

* add holosoma to locomotros

* various updates

* update joint zeroing etc

* ensure safefail with locomotion

* add unitree locomotion

* launch camera from g1 server

* publish at varying framerates

* fix async read in camera

* attempting to fix camera lag

* test camera speedup

* training

* inference works

* remove logging from pi0

* remove logging

* push local changes

* testing

* final changes

* revert control_utils

* revert utils

* revert

* revert g1

* revert again:

* revert utils

* push recents

* remove examples

* remove junk

* remove mjlog

* revergt edit_dataset

* Update lerobot_edit_dataset.py

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* undo teleop changes

* revert logging

* remove loggings

* remove loogs

* revert dataset tools

* Update dataset_tools.py

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* move gravity to utils

* revert changes

* remove matplotlib viewer (rerun works fine)

* factory revert

* send policy action directly

* recent changes

* implement flexible action space

* send empty command if arms are missing

* rename locomotion to controller

* add init

* implement feedback

* add feedback for teleoperator

* fix ruff

* fix ruff

* use read_latest

* fix zmq camera

* revert exo_serial

* simplify PR

* revert exo_changes

* revert camera_zmq

* Update camera_zmq.py

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>

* remove frame duplication from zmq server

* revert channerfactoryinitialize

* keep channelfactoryinitialize

* remove zeroing out logic

* fix typo

* refactor teleop class

* simplify teleop further

* import armindex at the top

* fix visualizer again

* revert ik helper

* push stuff

* simplify image_server

* update image_server

* asd

* add threading logic

* simplify ik helper stuff

* simplify holosoma

* fix names

* fix docs

* revert leg override

* clean connect

* fix controller

* fix ruff

* clean teleoperator

* set_from_wireless

* avoid double initializations

* refactor robot class

* fix pre-commit

* update docs

* update docs format

* add teleop instructions

* unitree_g1 specific exception in record/teleoperate

* add thumbnail to docs

* add thumbnail to doc

* refactor(unitree): multiple improvements (#3103)

* refactor(unitree): multiple improvements

* test(unitree): added tests + improved installation instructions

* refactor(robots): minor changes unitree robot kinematic

* chore(robots): rename g1 kinematics file

---------

Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2026-03-08 11:33:24 +01:00
Shun.Sasaki 6139b133ca fix(async_inference): restore robot module imports in robot_client.py (#3081) 2026-03-06 17:14:14 +01:00
Steven Palma 85de893fa7 fix(ci): skip HF log in (and tests) in forks and community PRs (#3097)
* fix(ci): skip HF log in (and tests) in forks and community PRs

* chore(test): remove comment about test meant to be only run locally

* fix(tests): no hf log in decorator for xvla

* fix(test): no decorator in yield
2026-03-06 16:33:43 +01:00
Steven Palma a4c66e530b chore(docs): remove pi installation note (#3095) 2026-03-06 15:52:54 +01:00
Steven Palma a225127527 chore(dependencies): sync intelrealsense + added notes (#3094) 2026-03-06 10:50:46 +01:00
Steven Palma e489ba24fc feat(dependencies): require Python 3.12+ as minimum version (#3023)
* feat(dependecies): upgrade to python3.12

* fix(test): processor regex message

* fix(test): processor regex message

* fix(dependecies): resolve all tags in python 3.12

* fix(dependecies): add more hints to faster resolve

* chore(dependecies): remove cli tag huggingface-hub dep

* refactor(policy): update eagle for python3.12

* chore(docs): update policy creation for python 3.12

* chore(test): skip failing tests in macos
2026-03-06 10:15:13 +01:00
Steven Palma d324ffe810 fix(ci): test only multi-gpu tests in multi-gpu runner (#3092) 2026-03-05 19:53:40 +01:00
Pepijn 1a24f770d3 Feat/slurm compute rabc script (#3041)
* Add SLURM SARM progress annotation script.

Provide a standalone two-stage compute/aggregate pipeline for RA-BC progress generation so large datasets can be processed in parallel and optionally uploaded to the Hub.

Made-with: Cursor

* fix pr comments

* remove comments
2026-03-05 18:27:58 +01:00
Caroline Pascal 92fba37225 fix(num_frames): fixing redundant frames count in conversion script (#3091) 2026-03-05 15:49:50 +01:00
Steven Palma 3e45120272 fix(ci): log in HF for gated repo in nightly workflows (#3089)
* fix(ci): log in HF for gated repo in nightly workflows

* fix(ci): add env var

* fix(ci): remove 10 min limit for multi-gpu nightly
2026-03-05 13:22:37 +01:00
Steven Palma f0d2b37beb chore(dependencies): bump transformers v5 (#2964)
* chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy

* chore(dependencies): bump pi0 family to transformers v5

* chore(dependencies): bump wall x to transformers v5

* chore(dependencies): bump gr00t to transformers v5

* chore(style): fix pre-commit

* fix(policy): xvla forced_bos_token missing

* test(rl): skip ci tests for resnet10

* Fix: full pi models support for transformer v5 (#2967)

* fix(pi): remove loss truncation

* fix(pi): remove state padding before tokenization

* fix(pi): fix image padding value

* fix from_pretrain

* add transformer v5 changes

* remove reference

* more fixes

* make it work

* add support for rest of pi family

* add pifast work

* more changes

* more changes

* more cleanup

* fix torch params

* dtype fix

* torch compile

* embed mismatch fix

* revert groot

* more nit fixes

* remove unused classes

* more fixes

* revert

* nit

* torch dtype warning fix

* but back dynamic renaming

* add tie embedding

---------

Co-authored-by: Yufei Sun <skieyfly@gmail.com>

* chore: fix XVLA in transformers v5 (#3006)

* test(policies): enable wall x CI testing

* style(test): pre-commit check

* style(test): pre-commit

* fix wall x for transformer v5 (#3008)

* tv5 fix

* various wall x fixes

* Delete tests/policies/pi0_pi05/print_pi05_output_logits.py

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* sync modeling_florence2.py with chore/bump_transformers_v5

* more

* more fixes

* more

* remove comment

* more

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* chore(dependencies): adjust dependencies versioning after transformers v5 (#3034)

* chore(dependecies): adjust dependecies versioning after transformers v5

* fix(policies): remove deprecated input_embeds

* fix(policies): dict _tied_weights_keys

* chore(depedencies): common qwen-vl-utils

* chore(dependencies): bump transformers to 5.2

* Fix policy testing for tv5 (#3032)

* fix ci logger

* other fix

* fix mypy

* change logits to torch2.10

* skip wallx|

* remove logging

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>

* feat(ci): log into HF to unblock some CI tests (#3007)

* feat(ci): log into HF to unblock some CI tests

* chore(ci): change hf call + secret name

* fix(ci): temp fix for pi0 rtc test

* test(policies): require_cuda for unblocked tests

* test(policies): require_cuda wall_x

* fic(tests): require_cuda outter most for pi0

* fix(test): return instead of yield

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* style(test): fix pre-commit

* chore(deps): upgrade transformers (#3050)

* chore(test): use lerobot model

* fix(policies): change default action tokenizer for wall x

* sample on cpu

* Revert "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5"

This reverts commit d9b76755f7, reversing
changes made to 89359cb0b6.

* Reapply "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5"

This reverts commit c9914db78b.

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Yufei Sun <skieyfly@gmail.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
2026-03-05 09:25:26 +01:00
Caroline Pascal cbc8bfb2e6 chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label (#3077)
* chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label

* chore(task): renamming the default index label in the tasks DataFrame to task

* Revert "chore(docstrings): updating v2.1-v3.0 conversion script docstrings to match the new task label"

This reverts commit f55de3255278f23f18b5d955565f6768d094951d.

* chore(docstrings): updating docstrings to match dataset v3.0 architecture

* chore(format): formatting code
2026-03-04 17:59:03 +01:00
Paul Crook 0d1be72dc8 Fixing metadata indexing when writing new Parquet file (#2941)
* Fixing metadata indexing when writing new Parquet file

Summary:
  - addressing this issue: https://github.com/huggingface/lerobot/issues/2401
  - vibe-coded bugfix by Claude Sonnet 4.5

* Backing out changes to convert_videos_of_camera

* Addressing Ruff pre-commit complaint

Summary:
 - addressing "SIM113 Use `enumerate()` for index variable `ep_idx` in `for` loop"

---------

Co-authored-by: Paul <238953601+pac-robotics@users.noreply.github.com>
2026-03-04 16:53:34 +01:00
Maxime Ellerbach 96b7c212c4 chore(docs): updating deprecated huggingface-cli to hf (#3071)
* chore(docs): updating deprecated huggingface-cli to hf

* small typo in my-org
2026-03-04 15:08:49 +01:00
Caroline Pascal 4303b3c930 chore(root): fixing root semantics in convert_dataset script (#3073)
* fix(root): fixing root semantincs in convert_dataset script

* fix(\): fixing command syntax in dataset conversion script

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-03-04 11:11:21 +01:00
Caroline Pascal 63dca86df8 fix(dataset edit tools): clarifying root argument usage + adding related features (#3049)
* fix(root): adding proper support for the root and new_root arguments

* feat(roots): adding a roots agrument for the merge operation

* chore(clean): cleaning up code

* chore(doctrings): updating doctrings with new features

* fix(repo_id): setting repo_id to None when not needed

* fix(roots/repo_ids): making mypy happy by using repo_ids and roots for merge operation

* fix(path): fixing path related issues

* fix(repo_id): fixing issues related to repo_id

* chore(doctrings): updating docstrings + fix typo

* chore(clean): cleaning code

* fix(split new_repo_id): reverting new_repo_id addition for split operation

* docs(dosctrings): completing docstrings

* fix(repo_ids/roots): improving checks for repo_ids/roots lengths

* fix(repo_ids): making repo_ids optional in MergeConfig but raise if not given

* fix(docstrings): fixing docstrings for split operation

* fix(hints): updating get_output_path hints to accept paths as strings too

* fix(y/N prompts): removing y/N prompts in lerobot_edit_dataset

* fix(merge repo_id): fixing merge operation to use new_repo_id instead of repo_id

* fix(typo): fixing typo in doctrings
2026-03-03 15:40:46 +01:00
Caroline Pascal 8a0cc3d664 fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets (#3068)
* fix(frame_index): making rerun's "frame_index" timeline compatible with behaviour1k datasets

* fix(segfault risk): removing segfault risk by calling  batch["index"] in the dataloader loop
2026-03-03 11:55:09 +01:00
Bernie Telles 8bb8ed4803 Improve policy_device documentation for async.mdx (#3060) 2026-03-02 15:35:15 +01:00
Steven Palma 095856b06a chore: add AI policy (#3055) 2026-02-28 14:41:28 +01:00
Steven Palma 563f42bdb1 chore(dependencies): Bump lerobot to 0.4.5 (#3051) 2026-02-27 19:29:35 +01:00
Caroline Pascal 8fff0fde7c chore(docstrings): fixing deprecated root argument description in LeRobotDataset class (#3035)
* chore(docstrings): fixing deprecated `root` argument docstrings in LeRobotDataset class

* chore(draccus): updating draccus CLI help

* chore(revert): reverting changes in lerobot_dataset_viz.py

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:22:44 +01:00
Pepijn 04de496547 fix(logging): avoid double-counting samples across processes (#3045) 2026-02-27 17:45:19 +01:00
Khalil Meftah baf9b50365 Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046)
* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
2026-02-27 17:44:53 +01:00
Jade Choghari a0fdbf037a feat(policies): add Smolvla torch compile support (#3043)
* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Change LIBERO init_state_id when reset.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* pre-commit run

* Add torch.compile for smolvla

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* Add torch.compile for smolvla

Add model compilation option for improved performance.

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>

* first

---------

Signed-off-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Aoqun Jin <aojiaojiao@foxmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-27 18:58:36 +03:00
Khalil Meftah c085531b17 fix: add missing openarm_mini import to CLI scripts (#3028) 2026-02-27 15:46:31 +01:00
Steven Palma c7c6205332 chore(scripts): no spam log when no action (#3042) 2026-02-27 15:26:56 +01:00
Michio Sun 4e54be1334 fix(datasets): skip warning when MultiLeRobotDataset features are identical (#3019)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 17:42:22 +01:00
Damien LaRocque fde9d08281 feat(async_inference) Enable plugins with async inference (#2425)
* feat(async-inference) Try using async inference server with plugins

* Fix import

* Fix import error in Robot Client

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-26 14:41:32 +01:00
Khalil Meftah 46044fed75 Fix: remove device_map from SmolVLA model loading (#3029)
* Fix SmolVLA meta tensor error by removing device_map

- Remove device_map parameter from VLM model loading
- Change torch_dtype from string to torch.bfloat16
- Add explicit .to(device) calls after initialization

This resolves NotImplementedError when training SmolVLA policy.
Fixes meta tensor copy issue in factory.py:418.

* fix: remove manual device movement logic and fix dtype handling

---------

Co-authored-by: Highsky7 <albert31115@gmail.com>
2026-02-26 13:28:46 +01:00
Khalil Meftah 975dcad918 Feat(teleoperators): add OpenArm Mini teleoperator (#3022)
* add OpenArm Mini config and module init

* add OpenArm Mini teleoperator implementation

* add OpenArm Mini into factory and setup motors

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-02-25 18:46:55 +01:00
Cotton Hu d0b58190da fix(policies): support dp train when n_obs_steps=1 (#2430)
Co-authored-by: hukongtao <hukongtao@agibot.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-25 17:36:31 +01:00
Mishig 9a5ab8ffab feat: add visualization badge to card template and update dataset card creation with repo_id (#3005)
* feat: add visualization badge to card template and update dataset card creation with repo_id

* Update src/lerobot/datasets/card_template.md

* Update src/lerobot/datasets/card_template.md

---------

Signed-off-by: Mishig <dmishig@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-25 16:02:40 +01:00
Khalil Meftah 7541d72130 Fix SARM dense_only mode: always load episodes_df for target computation (#3021)
* fix annotation mode check

* fix: SARM dense_only mode always load episodes_df for target computation

---------

Co-authored-by: John Newsom <jackmnewsom@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-02-25 13:28:01 +01:00
Jash Shah 0317a15bf1 fix(video): replace assertions with proper exceptions in video frame decoding (#3016)
Replaced assert statements with FrameTimestampError exceptions in
decode_video_frames_torchvision and decode_video_frames_torchcodec.

Assertions are unsuitable for runtime validation because they can be
silently disabled with python -O, and they produce unhelpful
AssertionError tracebacks. The codebase already defines
FrameTimestampError for this exact purpose but it was only used
in one of the three validation sites.

Also removed AssertionError from the except clause in
LeRobotDataset.__init__, which was masking video timestamp errors
by silently triggering a dataset re-download instead of surfacing
the actual problem.
2026-02-25 12:29:22 +01:00
Jash Shah f138e5948a Fix metaworld_config.json not bundled in pip installs and AttributeError crash (#3017)
1. Include metaworld_config.json in package distributions by adding it to
   both MANIFEST.in (for sdist) and pyproject.toml package-data (for wheels).
   Without this, pip-installed lerobot raises FileNotFoundError when
   importing the metaworld environment.

2. Fix crash in sanity_check_dataset_name where the error message accesses
   policy_cfg.type when policy_cfg is None, raising AttributeError instead
   of the intended ValueError.

Fixes #2958
2026-02-25 12:29:10 +01:00
Martin Kiefel 8fef4ddab8 fix(dataset): Fix reindexing bug for videos on splits (#2548)
* fix(dataset): Reindex videos based on frame and not on time

Sometimes during split operations the frame timestamp floating
precision leads to frame ending up in the wrong split.

This changes fixes the issues by directly working with frame indices
instead.

* Fix formatting
2026-02-25 11:57:07 +01:00
Steven Palma 18d9cb5ac4 feat(scripts): Integrate tqdm for training progress visualization (#3010) 2026-02-24 19:10:43 +01:00
Steven Palma 5095ab0845 fix(ci): permissions triton (#3011) 2026-02-24 19:09:34 +01:00
Jash Shah dac1efd13d feat: Enable torch.compile for DiffusionPolicy inference (#2486)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-24 17:29:08 +01:00
236 changed files with 8234 additions and 4457 deletions
+2 -1
View File
@@ -44,7 +44,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -91,6 +91,7 @@ jobs:
run: uv sync --extra "test"
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
+5 -1
View File
@@ -37,7 +37,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest action is built, canceling older runs.
@@ -89,6 +89,7 @@ jobs:
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
@@ -181,9 +182,12 @@ jobs:
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
+20 -4
View File
@@ -28,7 +28,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
@@ -119,6 +119,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
@@ -130,6 +131,11 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on CPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -146,6 +152,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -157,6 +164,11 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -174,6 +186,7 @@ jobs:
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -185,12 +198,15 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Verify GPU availability
run: |
nvidia-smi
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10
run: pytest -vv tests/training/
+1 -1
View File
@@ -50,7 +50,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Run pre-commit hooks
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
+2 -10
View File
@@ -22,7 +22,7 @@ on:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
jobs:
# This job builds the Python package and publishes it to PyPI
@@ -45,7 +45,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
- name: Extract Version
id: extract_info
@@ -83,14 +83,6 @@ jobs:
exit 1
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
run: python -m pip install build
+13 -2
View File
@@ -29,7 +29,7 @@ permissions:
# Sets up the environment variables
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
PYTHON_VERSION: "3.12"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
# Ensures that only the latest action is built, canceling older runs.
@@ -48,6 +48,7 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -79,7 +80,11 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv
@@ -137,6 +142,7 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -148,6 +154,11 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv
- name: Run end-to-end tests
+2 -2
View File
@@ -13,7 +13,7 @@
# limitations under the License.
default_language_version:
python: python3.10
python: python3.12
exclude: "tests/artifacts/.*\\.safetensors$"
@@ -55,7 +55,7 @@ repos:
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
args: [--py312-plus]
##### Markdown Quality #####
- repo: https://github.com/rbubley/mirrors-prettier
+25
View File
@@ -0,0 +1,25 @@
# AI Usage Policy
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
## Remember the Human Maintainers
Please remember that LeRobot is maintained by a dedicated team of humans.
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
## AI is Welcome Here
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
+4 -4
View File
@@ -2,7 +2,7 @@
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) and our [AI policy](https://github.com/huggingface/lerobot/blob/main/AI_POLICY.md).
## Ways to Contribute
@@ -32,7 +32,7 @@ git remote add upstream https://github.com/huggingface/lerobot.git
### 2. Environment Installation
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
Please follow our [Installation Guide](https://huggingface.co/docs/lerobot/installation) for the environment setup & installation from source.
## Running Tests & Quality Checks
@@ -75,8 +75,8 @@ pytest -sv tests/test_specific_feature.py
Use the templates for required fields and examples.
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).
One member of the LeRobot team will then review your contribution.
+1
View File
@@ -1,2 +1,3 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
+19 -2
View File
@@ -135,7 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
## Citation
If you use LeRobot in your research, please cite:
If you use LeRobot in your project, please cite the GitHub repository to acknowledge the ongoing development and contributors:
```bibtex
@misc{cadene2024lerobot,
@@ -146,9 +146,26 @@ If you use LeRobot in your research, please cite:
}
```
If you are referencing our research or the academic paper, please also cite our ICLR publication:
<details>
<summary><b>ICLR 2026 Paper</b></summary>
```bibtex
@inproceedings{cadenelerobot,
title={LeRobot: An Open-Source Library for End-to-End Robot Learning},
author={Cadene, Remi and Alibert, Simon and Capuano, Francesco and Aractingi, Michel and Zouitine, Adil and Kooijmans, Pepijn and Choghari, Jade and Russi, Martino and Pascal, Caroline and Palma, Steven and Shukor, Mustafa and Moss, Jess and Soare, Alexander and Aubakirova, Dana and Lhoest, Quentin and Gallou\'edec, Quentin and Wolf, Thomas},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://arxiv.org/abs/2602.22818}
}
```
</details>
## Contribute
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
<p align="center">
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
+3 -1
View File
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
# Define Python version argument
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
# Configure environment variables
ENV DEBIAN_FRONTEND=noninteractive \
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
RUN uv pip install --no-cache ".[all]"
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
# Copy the rest of the application source code
# Make sure to have the git-LFS files for testing
COPY --chown=user_lerobot:user_lerobot . .
+3 -1
View File
@@ -18,8 +18,10 @@
# docker build -f docker/Dockerfile.user -t lerobot-user .
# docker run -it --rm lerobot-user
# With USB physical access : docker run -it --device=/dev/ -v /dev/:/dev/ --rm lerobot-user
# Configure the base image
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim
# Configure environment variables
+1 -1
View File
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
+4 -4
View File
@@ -32,7 +32,7 @@ version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.11"
requires-python = ">= 3.12"
[build-system]
build-backend = # your-build-backend
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any
from typing import Any
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```
@@ -102,7 +102,7 @@ Create processor functions:
```python
# processor_my_custom_policy.py
from typing import Dict, Any
from typing import Any
import torch
+14 -10
View File
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
### Hardware
- EarthRover Mini robot
- Computer with Python 3.10 or newer
- Computer with Python 3.12 or newer
- Internet connection
### Setting Up the Frodobots SDK
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face username:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
@@ -204,22 +204,26 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
Your dataset includes:
**Your Actions (2 things)**:
**Your Actions (2 features)**:
- How much you moved forward/backward
- How much you turned left/right
- `linear_velocity`: How much you moved forward/backward
- `angular_velocity`: How much you turned left/right
**Robot Observations (12 things)**:
**Robot Observations (24 features)**:
- Front camera video
- Rear camera video
- Current speed
- Battery level
- Which way the robot is facing
- GPS location (latitude, longitude, signal strength)
- Orientation
- GPS (latitude, longitude, signal strength)
- Network signal strength
- Vibration level
- Lamp status (on/off)
- Lamp state (on/off)
- Accelerometer (x, y, z)
- Gyroscope (x, y, z)
- Magnetometer (x, y, z)
- Wheel RPMs (4 wheels)
### Where Your Data Goes
+2 -2
View File
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
pip install huggingface_hub
# Login to Hugging Face
huggingface-cli login
hf auth login
# Create a new repository
huggingface-cli repo create my-custom-env --type space --org my-org
hf repo create my-org/my-custom-env
# Initialize git and push
git init
+5 -5
View File
@@ -159,13 +159,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
@@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
You can also push your local dataset to the Hub manually, running:
```bash
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
```
#### Record function
@@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
Once training is done, upload the latest checkpoint with:
```bash
huggingface-cli upload ${HF_USER}/act_so101_test \
hf upload ${HF_USER}/act_so101_test \
outputs/train/act_so101_test/checkpoints/last/pretrained_model
```
@@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with:
```bash
CKPT=010000
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
hf upload ${HF_USER}/act_so101_test${CKPT} \
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
```
+60 -13
View File
@@ -1,8 +1,8 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
This guide uses `conda` (via miniforge) to manage environments (recommended). If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and `ffmpeg` installed with the `libsvtav1` encoder, then skip ahead to [Environment Setup](#step-2-environment-setup).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
## Step 1 (`conda` only): Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
@@ -11,22 +11,47 @@ bash Miniforge3-$(uname)-$(uname -m).sh
## Step 2: Environment Setup
Create a virtual environment with Python 3.10, using conda:
Create a virtual environment with Python 3.12:
<!-- prettier-ignore-start -->
<hfoptions id="create_venv">
<hfoption id="conda">
```bash
conda create -y -n lerobot python=3.10
conda create -y -n lerobot python=3.12
```
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
</hfoption>
<hfoption id="uv">
```bash
uv python install 3.12
uv venv --python 3.12
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
Then activate your virtual environment, you have to do this each time you open a shell to use lerobot:
<!-- prettier-ignore-start -->
<hfoptions id="activate_venv">
<hfoption id="conda">```bash
conda activate lerobot
```</hfoption>
<hfoption id="uv">
```bash
# Linux/macOSsource
source .venv/bin/activate
# Windows PowerShell
source .venv\Scripts\Activate.ps1
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
When using `conda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
ffmpeg -version # ffmpeg 8.X is not yet supported !
```
> [!TIP]
@@ -47,6 +72,9 @@ conda install ffmpeg -c conda-forge
> conda install evdev -c conda-forge
> ```
> [!IMPORTANT]
> If you are using `uv` you will have to install `ffmpeg` system-wide (outside of the virtual environment). You rely on `uv` and `torchcodec` ability to dynamically link to the system `ffmpeg`.
## Step 3: Install LeRobot 🤗
### From Source
@@ -60,23 +88,45 @@ cd lerobot
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
<!-- prettier-ignore-start -->
<hfoptions id="install_lerobot_src">
<hfoption id="conda">
```bash
pip install -e .
```
</hfoption>
<hfoption id="uv">
```bash
uv pip install -e .
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
### Installation from PyPI
**Core Library:**
Install the base package with:
<!-- prettier-ignore-start -->
<hfoptions id="install_lerobot_pypi">
<hfoption id="conda">
```bash
pip install lerobot
```
</hfoption>
<hfoption id="uv">
```bash
uv pip install lerobot
```
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
_This installs only the default dependencies._
**Extra Features:**
To install additional functionality, use one of the following:
To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
```bash
pip install 'lerobot[all]' # All available features
@@ -90,13 +140,10 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
To install these for linux run:
To install these for Linux run:
```bash
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
@@ -106,7 +153,7 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/
## Optional dependencies
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`.
LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
### Simulations
+2 -2
View File
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
Add your token to the CLI by running this command:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
-5
View File
@@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
-5
View File
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
-5
View File
@@ -43,11 +43,6 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training a Custom FAST Tokenizer
You have two options for the FAST tokenizer:
+200 -205
View File
@@ -1,23 +1,72 @@
# Unitree G1
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/unitree_thumbnail.jpg"
alt="Unitree G1 locomanipulation demo"
style={{ width: "100%" }}
/>
## About
We support both 29 and 23 DOF G1 EDU version. We introduce:
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
- **Simulation mode** for testing policies without the physical robot in mujoco
The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train locomanipulation policies, test in sim, and more. Both 29 and 23 DoF variants are supported.
---
## Connection guide
## Part 1: Getting Started
### Step 1: Configure Ethernet Interface
### Install the Unitree SDK
Set a static IP on the same subnet as the robot:
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation). Tested with `unitree_sdk2py==1.0.1` and `cyclonedds==0.10.2`:
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
pip install -e .
cd ..
```
### Install LeRobot
```bash
conda install ffmpeg -c conda-forge
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### Test the Installation (Simulation)
The simulation environment has its own dependencies. Check the Simulation environment dependencies: [Unitree G1 Mujoco EnvHub](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main).
```bash
pip install mujoco loguru msgpack msgpack-numpy
```
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.id=wbc_unitree \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30, "warmup_s": 5}}' \
--display_data=true \
--robot.controller=GrootLocomotionController
```
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. You can connect a gamepad to your machine before launching in order to control the robot's locomotion in sim. We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) via `--robot.controller`.
- Press `9` to release the robot
- Press `7` / `8` to increase / decrease waist height
### Connect to the Physical Robot
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
```bash
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
@@ -26,47 +75,23 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0
sudo ip link set enp131s0 up
```
**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164.
### Step 2: SSH into the Robot
### SSH into the Robot
```bash
ssh unitree@192.168.123.164
# Password: 123
```
You should now be connected to the G1's Orin.
### Share Internet via Ethernet
---
## Part 2: Enable WiFi on the Robot
Wlan0 is disabled by default on the G1. To enable it:
### Step 1: Enable WiFi Hardware
```bash
sudo rfkill unblock wifi
sudo rfkill unblock all
# Bring up wlan0
sudo ip link set wlan0 up
# Enable NetworkManager control of wlan0
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
### Step 2: Enable Internet Forwarding
The G1 needs internet access to clone repos and install packages. Share your laptop's connection over Ethernet:
**On your laptop:**
```bash
# Enable IP forwarding
sudo sysctl -w net.ipv4.ip_forward=1
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
# Replace wlp132s0f0 with your WiFi interface name
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
@@ -75,223 +100,193 @@ sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
**On the G1:**
```bash
# Add laptop as default gateway
sudo ip route del default 2>/dev/null || true
sudo ip route add default via 192.168.123.200 dev eth0
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
# Test connection
# Verify
ping -c 3 8.8.8.8
```
### Step 3: Connect to WiFi Network
### Install the Unitree SDK on the G1
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation):
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
python -m pip install -e .
cd ..
```
### Install LeRobot on the G1
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
python -m pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### (Optional) Enable WiFi on the Robot
For wireless SSH access, you can enable WiFi on the G1 (it's blocked by default):
```bash
sudo rfkill unblock all
sudo ip link set wlan0 up
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
**Connect to a WiFi network:**
```bash
# List available networks
nmcli device wifi list
# Connect to your WiFi (example)
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
sudo nmcli connection up "YourNetwork"
# Check WiFi IP address
ip a show wlan0
```
### Step 4: SSH Over WiFi
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
You can then SSH over WiFi instead of Ethernet:
```bash
ssh unitree@<YOUR_ROBOT_IP>
ssh unitree@<ROBOT_WIFI_IP>
# Password: 123
```
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
---
## Part 2: Teleoperation & Locomotion
### Run the Robot Server
On the robot (from `~/lerobot`):
```bash
cd ~/lerobot
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
```
### Run the Locomotion Policy
You can run the teleoperation client from your laptop over Ethernet, over WiFi (experimental), or directly on the robot itself. Mind potential latency introduced by your network.
**From your laptop:**
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.robot_ip=<ROBOT_IP> \
--teleop.type=unitree_g1 \
--teleop.id=wbc_unitree \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--display_data=true \
--robot.controller=HolosomaLocomotionController
```
We support both [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) and [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) via `--robot.controller`.
---
## Part 3: Robot Server Setup
## Part 3: Loco-Manipulation with the Homunculus Exoskeleton
### Step 1: Install LeRobot on the Orin
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Check it out [here](https://github.com/nepyope/hmc_exo).
SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
### Step 2: Run the Robot Server
On the robot:
```bash
python src/lerobot/robots/unitree_g1/run_g1_server.py
```
**Important**: Keep this terminal running. The server must be active for remote control.
---
## Part 4: Controlling the robot
With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
### Step 2: Update Robot IP in Config
Edit the config file to match your robot's WiFi IP:
```python
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
```
### Step 3: Run the Locomotion Policy
```bash
# Run GR00T locomotion controller
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
# Run Holosoma locomotion controller
python examples/unitree_g1/holosoma_locomotion.py
```
Press `Ctrl+C` to stop the policy.
---
## Running in Simulation Mode (MuJoCo)
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
### Calibrate Exoskeleton Teleoperator
### Calibrate
```bash
lerobot-calibrate \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo
```
### Teleoperate in Simulation
During calibration move each joint through its entire range. After fitting, move the joint in a neutral position and press `n` to advance.
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
```
### Record Dataset in Simulation
### Record a Dataset
```bash
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2
```
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
> **Note:** Omit `--teleop.left_arm_config.port` and `--teleop.right_arm_config.port` if you're only using the joystick.
Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/datasets/nepyope/unitree_box_move_blue_full)
---
## Running on Real Robot
## Part 4: Training & Inference
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
### Start the Camera Server
On the robot, start the ZMQ image server:
### Train
```bash
python src/lerobot/cameras/zmq/image_server.py
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/dataset-name \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
--job_name=pi05_training \
--policy.repo_id=your-username/your-repo-id \
--policy.pretrained_path=lerobot/pi05_base \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--policy.dtype=bfloat16 \
--policy.freeze_vision_encoder=false \
--policy.train_expert_only=false \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
Keep this running in a separate terminal for camera streaming during recording.
### Inference with RTC
### Teleoperate Real Robot
Once trained, we recommend deploying policies using inference-time RTC:
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--fps=100
python examples/rtc/eval_with_real_robot.py \
--policy.path=your-username/your-repo-id \
--policy.device=cuda \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.controller=HolosomaLocomotionController \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "<ROBOT_IP>", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--task="task_description" \
--duration=1000 \
--fps=30 \
--rtc.enabled=true
```
### Record Dataset on Real Robot
```bash
lerobot-record \
--robot.type=unitree_g1 \
--robot.is_simulation=false \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--teleop.type=unitree_g1 \
--teleop.left_arm_config.port=/dev/ttyACM1 \
--teleop.right_arm_config.port=/dev/ttyACM0 \
--teleop.id=exo \
--dataset.repo_id=your-username/dataset-name \
--dataset.single_task="Test" \
--dataset.num_episodes=2 \
--dataset.episode_time_s=5 \
--dataset.reset_time_s=5 \
--dataset.push_to_hub=true \
--dataset.streaming_encoding=true \
# --dataset.vcodec=auto \
--dataset.encoder_threads=2
```
**Note**: Update `server_address` to match your robot's camera server IP.
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
---
## Additional Resources
@@ -300,8 +295,8 @@ Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/da
- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl)
- [Holosoma](https://github.com/amazon-far/holosoma)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
- [Unitree IL LeRobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
---
_Last updated: December 2025_
_Last updated: March 2026_
+1 -1
View File
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
repo_id: str
# Episode to replay.
episode: int
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
+2 -1
View File
@@ -32,7 +32,8 @@ import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def main():
+490
View File
@@ -0,0 +1,490 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SLURM-distributed SARM RA-BC annotation pipeline.
Computes SARM progress values for all frames in a dataset, distributed across
SLURM workers, then merges the shards into a single sarm_progress.parquet.
Two subcommands, each a separate SLURM submission:
compute N workers, each computes progress for a subset of episodes
aggregate 1 worker, merges N shards into sarm_progress.parquet, pushes to hub
Usage:
python slurm_compute_rabc.py compute \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--stride 10 --device cpu --workers 50 --partition cpu
python slurm_compute_rabc.py aggregate \\
--repo-id user/dataset --reward-model-path user/sarm_model \\
--partition cpu --push-to-hub
"""
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
class ComputeProgressShards(PipelineStep):
"""Each worker computes SARM progress for its assigned episodes."""
def __init__(
self, repo_id, reward_model_path, stride=1, head_mode="sparse", device="cpu", shard_dir="rabc_shards"
):
super().__init__()
if stride < 1:
raise ValueError(f"stride must be >= 1, got {stride}")
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.stride = stride
self.head_mode = head_mode
self.device = device
self.shard_dir = shard_dir
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.policies.sarm.compute_rabc_weights import (
generate_all_frame_indices,
interpolate_progress,
load_sarm_resources,
)
from lerobot.utils.utils import init_logging
init_logging()
dataset, reward_model, preprocess = load_sarm_resources(
self.repo_id,
self.reward_model_path,
self.device,
)
if hasattr(preprocess, "eval"):
preprocess.eval()
for step in preprocess.steps:
if hasattr(step, "eval"):
step.eval()
image_key = reward_model.config.image_key
state_key = reward_model.config.state_key
frame_gap = reward_model.config.frame_gap
center_idx = reward_model.config.n_obs_steps // 2
dual_mode = reward_model.config.uses_dual_heads
compute_sparse = self.head_mode in ("sparse", "both") or not dual_mode
compute_dense = self.head_mode in ("dense", "both") and dual_mode
my_episodes = list(range(dataset.num_episodes))[rank::world_size]
if not my_episodes:
logging.info(f"Rank {rank}: no episodes assigned")
return
logging.info(f"Rank {rank}: {len(my_episodes)} / {dataset.num_episodes} episodes")
all_rows = []
for ep_idx in tqdm(my_episodes, desc=f"Rank {rank}"):
ep = dataset.meta.episodes[ep_idx]
ep_start, ep_end = ep["dataset_from_index"], ep["dataset_to_index"]
task = dataset[ep_start].get("task", "perform the task")
all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
if self.stride > 1:
compute_indices = [i for i in all_ep_indices if (i - ep_start) % self.stride == 0]
if (ep_end - 1) not in compute_indices:
compute_indices.append(ep_end - 1)
compute_indices = sorted(set(compute_indices))
else:
compute_indices = all_ep_indices
frame_results = {}
for qi in tqdm(compute_indices, desc=f" Ep {ep_idx}", leave=False):
try:
sample = dataset[qi]
batch = {
image_key: sample[image_key],
"task": task,
"index": qi,
"episode_index": ep_idx,
}
if state_key in sample:
batch[state_key] = sample[state_key]
with torch.no_grad():
processed = preprocess(batch)
vf = processed["video_features"].to(self.device)
tf = processed["text_features"].to(self.device)
sf = processed.get("state_features")
if sf is not None:
sf = sf.to(self.device)
lengths = processed.get("lengths")
sparse_val = dense_val = np.nan
if compute_sparse:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="sparse",
)
sparse_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
if compute_dense:
r = reward_model.calculate_rewards(
text_embeddings=tf,
video_embeddings=vf,
state_features=sf,
lengths=lengths,
return_all_frames=True,
head_mode="dense",
)
dense_val = float(r[0, center_idx] if r.ndim == 2 else r[center_idx])
frame_results[qi] = (sparse_val, dense_val)
except Exception as e:
logging.warning(f"Failed frame {qi}: {e}")
if not frame_results:
logging.warning(f"Episode {ep_idx}: all frames failed, skipping")
continue
# Interpolate to all frames in this episode
computed_idx = np.array(sorted(frame_results.keys()))
all_frame_arr = np.arange(ep_start, ep_end)
sparse_vals = np.array([frame_results[i][0] for i in computed_idx]) if compute_sparse else None
dense_vals = np.array([frame_results[i][1] for i in computed_idx]) if compute_dense else None
if self.stride > 1 and len(computed_idx) > 1:
if compute_sparse:
sparse_vals = interpolate_progress(computed_idx, sparse_vals, all_frame_arr)
if compute_dense:
dense_vals = interpolate_progress(computed_idx, dense_vals, all_frame_arr)
output_frames = all_frame_arr
else:
# Use only successfully computed frames to avoid indexing mismatch on failures
output_frames = computed_idx
for i, fi in enumerate(output_frames):
row = {"index": int(fi), "episode_index": ep_idx, "frame_index": int(fi - ep_start)}
if compute_sparse:
row["progress_sparse"] = float(sparse_vals[i])
if compute_dense:
row["progress_dense"] = float(dense_vals[i])
all_rows.append(row)
if all_rows:
import pandas as pd
df = pd.DataFrame(all_rows).sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
shard_dir = Path(self.shard_dir)
shard_dir.mkdir(parents=True, exist_ok=True)
out = shard_dir / f"shard_{rank:05d}.parquet"
pq.write_table(table, out)
logging.info(f"Rank {rank}: saved {len(df)} rows to {out}")
class AggregateProgress(PipelineStep):
"""Merge all shard parquets into final sarm_progress.parquet."""
def __init__(self, repo_id, reward_model_path, shard_dir="rabc_shards", push_to_hub=False):
super().__init__()
self.repo_id = repo_id
self.reward_model_path = reward_model_path
self.shard_dir = shard_dir
self.push_to_hub = push_to_hub
def run(self, data=None, rank: int = 0, world_size: int = 1):
import datetime
import logging
import os
from pathlib import Path
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import init_logging
init_logging()
if rank != 0:
return
shard_dir = Path(self.shard_dir)
shards = sorted(shard_dir.glob("shard_*.parquet"))
if not shards:
raise FileNotFoundError(f"No shards found in {shard_dir}")
# Log shard modification time range to help detect stale files
mtimes = [os.path.getmtime(s) for s in shards]
oldest = datetime.datetime.fromtimestamp(min(mtimes)).isoformat(timespec="seconds")
newest = datetime.datetime.fromtimestamp(max(mtimes)).isoformat(timespec="seconds")
logging.info(f"Aggregating {len(shards)} shards (oldest: {oldest}, newest: {newest})")
df = pd.concat([pd.read_parquet(s) for s in shards], ignore_index=True)
df = df.sort_values("index").reset_index(drop=True)
table = pa.Table.from_pandas(df, preserve_index=False)
table = table.replace_schema_metadata({b"reward_model_path": self.reward_model_path.encode()})
temp_ds = LeRobotDataset(self.repo_id, download_videos=False)
out_path = Path(temp_ds.root) / "sarm_progress.parquet"
out_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(table, out_path)
logging.info(f"Saved {len(df)} rows to {out_path}")
for col in ["progress_sparse", "progress_dense"]:
if col in df.columns:
v = df[col].dropna()
logging.info(
f"{col}: mean={v.mean():.4f} std={v.std():.4f} min={v.min():.4f} max={v.max():.4f}"
)
if self.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
hub_path = "sarm_progress.parquet"
logging.info(f"Uploading to {self.repo_id}/{hub_path}")
api.upload_file(
path_or_fileobj=str(out_path),
path_in_repo=hub_path,
repo_id=self.repo_id,
repo_type="dataset",
)
logging.info(f"Uploaded: https://huggingface.co/datasets/{self.repo_id}/blob/main/{hub_path}")
def make_compute_executor(
repo_id,
reward_model_path,
stride,
head_mode,
device,
shard_dir,
logs_dir,
job_name,
slurm,
workers,
partition,
cpus_per_task,
mem_per_cpu,
):
kwargs = {
"pipeline": [
ComputeProgressShards(repo_id, reward_model_path, stride, head_mode, device, str(shard_dir)),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": workers,
"workers": workers,
"time": "24:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": workers, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def make_aggregate_executor(
repo_id,
reward_model_path,
shard_dir,
logs_dir,
job_name,
slurm,
partition,
cpus_per_task,
mem_per_cpu,
push_to_hub,
):
kwargs = {
"pipeline": [
AggregateProgress(repo_id, reward_model_path, str(shard_dir), push_to_hub),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": 1,
"workers": 1,
"time": "02:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
return SlurmPipelineExecutor(**kwargs)
kwargs.update({"tasks": 1, "workers": 1})
return LocalPipelineExecutor(**kwargs)
def _add_shared_args(p):
p.add_argument(
"--repo-id",
type=str,
required=True,
help="Hugging Face repository identifier, e.g. 'user/dataset'.",
)
p.add_argument(
"--shard-dir",
type=Path,
default=Path("rabc_shards"),
help="Directory to read/write per-rank parquet shards.",
)
p.add_argument(
"--logs-dir",
type=Path,
default=Path("logs"),
help="Directory for datatrove logs.",
)
p.add_argument(
"--job-name",
type=str,
default=None,
help="SLURM job name (defaults to rabc_<subcommand>).",
)
p.add_argument(
"--slurm",
type=int,
default=1,
help="1 = submit via SLURM; 0 = run locally (useful for debugging).",
)
p.add_argument(
"--partition",
type=str,
default=None,
help="SLURM partition to submit to.",
)
p.add_argument(
"--cpus-per-task",
type=int,
default=4,
help="Number of CPUs per SLURM task.",
)
p.add_argument(
"--mem-per-cpu",
type=str,
default="4G",
help="Memory per CPU, e.g. '4G' or '1950M'.",
)
def main():
parser = argparse.ArgumentParser(
description="SLURM-distributed SARM RA-BC annotation pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
sub = parser.add_subparsers(dest="command", required=True)
# compute subcommand
cp = sub.add_parser(
"compute",
help="Distribute progress computation across SLURM workers.",
)
_add_shared_args(cp)
cp.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model.",
)
cp.add_argument(
"--stride",
type=int,
default=1,
help="Compute every Nth frame; intermediate frames are interpolated (must be >= 1).",
)
cp.add_argument(
"--head-mode",
type=str,
default="sparse",
choices=["sparse", "dense", "both"],
help="Which reward head(s) to compute.",
)
cp.add_argument(
"--device",
type=str,
default="cpu",
help="Device for reward model inference, e.g. 'cpu' or 'cuda'.",
)
cp.add_argument(
"--workers",
type=int,
default=50,
help="Number of parallel SLURM tasks (one shard per worker).",
)
# aggregate subcommand
ap = sub.add_parser(
"aggregate",
help="Merge per-rank shards into a single sarm_progress.parquet.",
)
_add_shared_args(ap)
ap.add_argument(
"--reward-model-path",
type=str,
required=True,
help="Path or HF repo id of the SARM reward model (stored in parquet metadata).",
)
ap.add_argument(
"--push-to-hub",
action="store_true",
help="Upload sarm_progress.parquet to the Hugging Face Hub after aggregation.",
)
args = parser.parse_args()
job_name = args.job_name or f"rabc_{args.command}"
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
kwargs["job_name"] = job_name
command = kwargs.pop("command")
executor = make_compute_executor(**kwargs) if command == "compute" else make_aggregate_executor(**kwargs)
executor.run()
if __name__ == "__main__":
main()
@@ -0,0 +1,717 @@
"""
Action consistency analysis for imitation learning datasets.
Two parallel analyses per dataset:
1. State-based: KNN in joint-state space → action chunk variance
2. Image-based: KNN in SigLIP embedding space → action chunk variance
Comparing them reveals whether visual similarity and proprioceptive similarity
agree on where the data is inconsistent — and images are what the policy
primarily sees.
"""
import json
from pathlib import Path
import av
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from huggingface_hub import snapshot_download
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from scipy.spatial import cKDTree
from transformers import AutoImageProcessor, AutoModel
DATASETS = [
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
]
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
MAX_FRAMES = 100_000
K_NEIGHBORS = 50
ACTION_CHUNK_SIZE = 30
CAMERA_KEY = "observation.images.base"
ENCODER_MODEL = "google/siglip-base-patch16-224"
ENCODE_BATCH_SIZE = 512
SEED = 42
DPI = 150
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
)
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
LEFT_CHAIN = [
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
((0, 0, 0), (0, 0, 0.1001), None),
((0, 0, 0), (0, 0, 0.08), None),
]
RIGHT_CHAIN = [
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
((0, 0, 0), (0, 0, 0.1001), None),
((0, 0, 0), (0, 0, 0.08), None),
]
# ── FK math ─────────────────────────────────────────────
def _rot_x(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
def _rot_y(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
def _rot_z(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
r, p, y = rpy
mat = np.eye(4)
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
mat[:3, 3] = xyz
return mat
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
n = len(angles)
ax = np.asarray(axis, dtype=np.float64)
ax = ax / np.linalg.norm(ax)
x, y, z = ax
c = np.cos(angles)
s = np.sin(angles)
t = 1 - c
rot = np.zeros((n, 4, 4))
rot[:, 0, 0] = t * x * x + c
rot[:, 0, 1] = t * x * y - s * z
rot[:, 0, 2] = t * x * z + s * y
rot[:, 1, 0] = t * x * y + s * z
rot[:, 1, 1] = t * y * y + c
rot[:, 1, 2] = t * y * z - s * x
rot[:, 2, 0] = t * x * z - s * y
rot[:, 2, 1] = t * y * z + s * x
rot[:, 2, 2] = t * z * z + c
rot[:, 3, 3] = 1.0
return rot
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
n = joint_angles.shape[0]
tf_batch = np.tile(np.eye(4), (n, 1, 1))
qi = 0
for rpy, xyz, axis in chain:
tf_batch = tf_batch @ _tf(rpy, xyz)
if axis is not None:
rot = _batch_axis_rot(axis, joint_angles[:, qi])
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
qi += 1
return tf_batch[:, :3, 3]
# ── Data helpers ────────────────────────────────────────
def _flatten_names(obj: object) -> list[str]:
if isinstance(obj, dict):
out: list[str] = []
for v in obj.values():
out.extend(_flatten_names(v))
return out
if isinstance(obj, (list, tuple)):
out = []
for item in obj:
if isinstance(item, (list, tuple, dict)):
out.extend(_flatten_names(item))
else:
out.append(str(item))
return out
return [str(obj)]
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
mx = np.max(np.abs(vals))
if mx > 360:
print(f" Unit detection: servo ticks (max={mx:.0f})")
return (vals - 2048) / 2048 * np.pi
if mx > 6.3:
print(f" Unit detection: degrees (max={mx:.1f})")
return np.deg2rad(vals)
print(f" Unit detection: radians (max={mx:.3f})")
return vals.astype(np.float64)
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
feat = features.get("observation.state", features.get(state_col, {}))
names = _flatten_names(feat.get("names", []))
left_idx: list[int] = []
right_idx: list[int] = []
if names and len(names) == n_dim:
names_l = [n.lower() for n in names]
print(f" Feature names: {names[:4]}{names[-4:]}")
for j in range(1, 8):
for i, nm in enumerate(names_l):
if f"left_joint_{j}" in nm and i not in left_idx:
left_idx.append(i)
break
for i, nm in enumerate(names_l):
if f"right_joint_{j}" in nm and i not in right_idx:
right_idx.append(i)
break
if len(left_idx) == 7 and len(right_idx) == 7:
print(f" Matched by name: left={left_idx} right={right_idx}")
return left_idx, right_idx
if n_dim >= 16:
print(" Falling back to positional: [0:7]=left, [8:15]=right")
return list(range(7)), list(range(8, 15))
if n_dim >= 14:
print(" Falling back to positional: [0:7]=left, [7:14]=right")
return list(range(7)), list(range(7, 14))
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
def download_data(repo_id: str, camera_key: str) -> Path:
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
return Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=[
"meta/**",
"data/**",
f"videos/{camera_key}/**",
],
)
)
# ── Data loading ────────────────────────────────────────
def _build_action_chunks(
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
) -> tuple[np.ndarray, np.ndarray]:
"""
For each frame, concatenate the next chunk_size actions from the same episode.
Returns (action_chunks, valid_mask).
"""
n = len(actions)
act_dim = actions.shape[1]
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
valid = np.zeros(n, dtype=bool)
for i in range(n):
end = i + chunk_size
if end > n:
continue
if episode_ids[i] != episode_ids[end - 1]:
continue
chunks[i] = actions[i:end].ravel()
valid[i] = True
return chunks, valid
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
"""
Load observation.state and action, build action chunks, subsample, normalize.
Also returns the original row indices (`chosen_idx`) for video frame mapping.
"""
info = json.loads((local / "meta" / "info.json").read_text())
features = info.get("features", {})
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
df = pd.concat(dfs, ignore_index=True)
n_total = len(df)
print(f" Total frames: {n_total:,}")
state_col = next((c for c in df.columns if "observation.state" in c), None)
action_col = next((c for c in df.columns if c == "action"), None)
if state_col is None:
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
if action_col is None:
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
ep_col = next((c for c in df.columns if c == "episode_index"), None)
if ep_col is None:
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
state_all = np.stack(df[state_col].values).astype(np.float64)
action_all = np.stack(df[action_col].values).astype(np.float64)
episode_all = df[ep_col].values.astype(np.int64)
n_dim = state_all.shape[1]
act_dim = action_all.shape[1]
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
print(f" Action chunk dim: {chunk_size * act_dim}")
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
print(" Building action chunks …")
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
valid_idx = np.where(valid)[0]
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
if len(valid_idx) > max_frames:
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
else:
chosen = valid_idx
print(f" Using {len(chosen):,} frames")
state_raw = state_all[chosen]
action_raw = action_chunks[chosen]
episode_ids = episode_all[chosen]
state_mean = state_raw.mean(axis=0)
state_std = state_raw.std(axis=0)
state_std[state_std < 1e-8] = 1.0
state_norm = (state_raw - state_mean) / state_std
action_mean = action_raw.mean(axis=0)
action_std = action_raw.std(axis=0)
action_std[action_std < 1e-8] = 1.0
action_norm = (action_raw - action_mean) / action_std
return {
"state_raw": state_raw,
"state_norm": state_norm,
"action_raw": action_raw,
"action_norm": action_norm,
"episode_ids": episode_ids,
"episode_all": episode_all,
"left_joint_idx": left_idx,
"right_joint_idx": right_idx,
"n_total": n_total,
"chosen_idx": chosen,
"df": df,
}
# ── Video → frame extraction ──────────────────────────────
def build_video_lookup(local: Path, camera_key: str) -> dict:
"""
Build a mapping from episode_index → {video_path, fps, from_ts}.
"""
info = json.loads((local / "meta" / "info.json").read_text())
fps = info["fps"]
video_template = info.get(
"video_path",
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
)
ep_rows = []
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
ep_rows.append(pd.read_parquet(pq))
ep_df = pd.concat(ep_rows, ignore_index=True)
chunk_col = f"videos/{camera_key}/chunk_index"
file_col = f"videos/{camera_key}/file_index"
ts_from = f"videos/{camera_key}/from_timestamp"
if chunk_col not in ep_df.columns:
chunk_col = f"{camera_key}/chunk_index"
file_col = f"{camera_key}/file_index"
ts_from = f"{camera_key}/from_timestamp"
lookup: dict[int, dict] = {}
for _, row in ep_df.iterrows():
ci = int(row[chunk_col])
fi = int(row[file_col])
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
lookup[int(row["episode_index"])] = {
"video_path": local / video_rel,
"from_ts": float(row[ts_from]),
"fps": fps,
}
return lookup
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
container = av.open(video_path)
stream = container.streams.video[0]
stream.thread_type = "AUTO"
decoded = []
for frame in container.decode(stream):
decoded.append(frame.to_ndarray(format="rgb24"))
container.close()
return decoded
def extract_frames(
chosen_idx: np.ndarray,
episode_all: np.ndarray,
video_lookup: dict,
) -> list[np.ndarray | None]:
"""
Extract RGB frames for each chosen global index using PyAV.
Returns list of (H, W, 3) RGB arrays (or None on failure).
"""
unique_eps = np.unique(episode_all)
ep_start: dict[int, int] = {}
for ep in unique_eps:
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
# Build jobs: (output_index, video_path, local_frame_number)
jobs: list[tuple[int, str, int]] = []
for out_i, global_i in enumerate(chosen_idx):
ep = int(episode_all[global_i])
info = video_lookup.get(ep)
if info is None:
continue
local_frame = global_i - ep_start[ep]
jobs.append((out_i, str(info["video_path"]), local_frame))
# Group by video file, decode each video once
from collections import defaultdict
video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
for out_i, vpath, local_frame in jobs:
video_jobs[vpath].append((out_i, local_frame))
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
extracted = 0
n_videos = len(video_jobs)
for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
if not Path(vpath).exists():
continue
try:
decoded = _decode_video_frames(vpath)
except Exception as exc:
print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
continue
for out_i, local_frame in frame_requests:
if 0 <= local_frame < len(decoded):
frames[out_i] = decoded[local_frame]
extracted += 1
if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
del decoded
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
return frames
# ── SigLIP encoding ─────────────────────────────────────
def encode_frames_siglip(
frames: list[np.ndarray | None],
model_name: str,
batch_size: int,
device: torch.device,
) -> np.ndarray:
"""
Encode RGB frames through SigLIP vision encoder.
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
"""
print(f" Loading SigLIP model: {model_name}")
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device).eval()
embed_dim = model.config.vision_config.hidden_size
n = len(frames)
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
valid_indices = [i for i, f in enumerate(frames) if f is not None]
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size}")
for batch_start in range(0, len(valid_indices), batch_size):
batch_idx = valid_indices[batch_start : batch_start + batch_size]
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
inputs = processor(images=pil_images, return_tensors="pt").to(device)
with torch.no_grad():
image_features = model.get_image_features(**inputs)
image_features = torch.nn.functional.normalize(image_features, dim=-1)
embeddings[batch_idx] = image_features.cpu().numpy()
done = min(batch_start + batch_size, len(valid_indices))
if done % (batch_size * 10) == 0 or done == len(valid_indices):
print(f" {done:,} / {len(valid_indices):,} encoded")
del model, processor
torch.cuda.empty_cache()
return embeddings
# ── KNN consistency ─────────────────────────────────────
def compute_consistency(
features: np.ndarray,
action_norm: np.ndarray,
episode_ids: np.ndarray,
k: int,
label: str = "",
) -> np.ndarray:
"""
For each frame, find K nearest neighbors in feature space from other episodes.
Return per-frame action variance (mean across action dims).
"""
n = len(features)
print(f" Building KD-tree on {n:,} vectors ({label}) …")
tree = cKDTree(features)
k_query = min(k * 3, n - 1)
print(f" Querying {k_query} neighbors per frame …")
_dists, indices = tree.query(features, k=k_query + 1)
indices = indices[:, 1:]
print(f" Computing cross-episode action variance ({label}) …")
variance = np.zeros(n)
for i in range(n):
ep_i = episode_ids[i]
neighbors = indices[i]
cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
if len(cross_ep) < 2:
variance[i] = 0.0
continue
neighbor_actions = action_norm[cross_ep]
variance[i] = np.mean(np.var(neighbor_actions, axis=0))
return variance
# ── Visualization ───────────────────────────────────────
def _style_ax(ax: plt.Axes) -> None:
ax.set_facecolor("#0d1117")
ax.tick_params(colors="#555", labelsize=8)
for spine in ax.spines.values():
spine.set_color("#333")
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
_style_ax(ax)
median_var = np.median(variance)
mean_var = np.mean(variance)
nonzero = variance[variance > 0]
if len(nonzero) > 0:
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
ax.set_xscale("log")
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
ax.set_ylabel("Frame count", color="#888", fontsize=10)
ax.set_title(title, color="white", fontsize=11, pad=10)
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
def _plot_episode_curves(
ax: plt.Axes,
var_state: np.ndarray,
var_image: np.ndarray,
episode_ids: np.ndarray,
title: str,
) -> None:
_style_ax(ax)
unique_eps = np.unique(episode_ids)
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
sorted_s = np.sort(ep_means_s)[::-1]
sorted_i = np.sort(ep_means_i)[::-1]
ep_x = np.arange(len(unique_eps))
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
ax.set_title(title, color="white", fontsize=11, pad=10)
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
def _plot_heatmap(
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
) -> None:
_style_ax(ax)
order = np.argsort(variance)
pts = tcp_xz[order]
var_sorted = variance[order]
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
sc = ax.scatter(
pts[:, 0],
pts[:, 1],
c=var_sorted,
cmap=CONSISTENCY_CMAP,
s=0.5,
alpha=0.6,
vmin=vmin,
vmax=vmax,
rasterized=True,
)
ax.set_xlabel("X (m)", color="#888", fontsize=10)
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
ax.set_title(title, color="white", fontsize=11, pad=10)
ax.set_aspect("equal")
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
cbar.set_label("Action variance", color="white", fontsize=9)
cbar.ax.tick_params(colors="#aaa", labelsize=7)
def render(results: list[dict], out_path: Path) -> None:
"""
4-row x N-column figure:
Row 0: State-based variance histogram
Row 1: Image-based variance histogram
Row 2: Per-episode curves (both overlaid)
Row 3: Spatial heatmap (image-based variance)
"""
n_ds = len(results)
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
if n_ds == 1:
axes = axes[:, np.newaxis]
headline_parts = []
for col, r in enumerate(results):
label = r["label"]
var_s = r["var_state"]
var_i = r["var_image"]
tcp_xz = r["tcp_xz"]
episode_ids = r["episode_ids"]
med_s = np.median(var_s)
med_i = np.median(var_i)
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
_plot_histogram(
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
)
_plot_episode_curves(
axes[2, col],
var_s,
var_i,
episode_ids,
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
)
_plot_heatmap(
axes[3, col],
fig,
tcp_xz,
var_i,
f"{label}\nImage-based variance by TCP position (XZ)",
)
fig.suptitle(
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
+ " | ".join(headline_parts),
color="white",
fontsize=15,
y=0.99,
)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
plt.close()
print(f"\n✓ Saved: {out_path}")
# ── Main ────────────────────────────────────────────────
def main() -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
rng = np.random.default_rng(SEED)
results = []
for ds in DATASETS:
repo_id, label = ds["repo_id"], ds["label"]
print(f"\n{'=' * 60}")
print(f" {label}: {repo_id}")
print(f"{'=' * 60}")
local = download_data(repo_id, CAMERA_KEY)
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
# --- State-based KNN ---
var_state = compute_consistency(
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
)
print(
f" State variance: median={np.median(var_state):.4f} "
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
)
# --- Image-based KNN ---
print("\n Preparing image embeddings …")
video_lookup = build_video_lookup(local, CAMERA_KEY)
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
del frames # free memory
var_image = compute_consistency(
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
)
print(
f" Image variance: median={np.median(var_image):.4f} "
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
)
# FK for spatial heatmap
print(" Computing FK for spatial heatmap …")
left_raw = data["state_raw"][:, data["left_joint_idx"]]
left_rad = _detect_and_convert(left_raw)
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
tcp_xz = left_tcp[:, [0, 2]]
results.append(
{
"label": label,
"var_state": var_state,
"var_image": var_image,
"episode_ids": data["episode_ids"],
"tcp_xz": tcp_xz,
"n_total": data["n_total"],
}
)
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
render(results, out)
# Save worst-episodes summary (image-based, since that's the stronger signal)
worst_summary = {}
for r in results:
unique_eps = np.unique(r["episode_ids"])
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
worst_path.write_text(json.dumps(worst_summary, indent=2))
print(f"✓ Saved worst episodes: {worst_path}")
if __name__ == "__main__":
main()
@@ -0,0 +1,178 @@
"""
Create a JPG grid of random frames sampled from a LeRobot video dataset.
Downloads metadata + video chunks from HuggingFace, picks random frames,
decodes them, and tiles into a single image.
"""
import json
import random
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
REPO_ID = "lerobot-data-collection/level2_final_quality3"
CAMERA_KEY = "observation.images.base"
GRID_COLS = 15
GRID_ROWS = 10
THUMB_WIDTH = 160
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
SEED = 1
def download_metadata(repo_id: str) -> Path:
"""Download only metadata (no videos yet)."""
print(f"[1/3] Downloading metadata for {repo_id}")
return Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**"],
ignore_patterns=["*.mp4"],
)
)
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
info = json.loads((local / "meta" / "info.json").read_text())
fps = info["fps"]
features = info["features"]
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
if not video_keys:
raise RuntimeError("No video keys found in dataset features")
if CAMERA_KEY is not None:
if CAMERA_KEY not in video_keys:
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
cam = CAMERA_KEY
else:
cam = video_keys[0]
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
ep_rows = []
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
ep_rows.append(pd.read_parquet(pq))
ep_df = pd.concat(ep_rows, ignore_index=True)
video_template = info.get(
"video_path",
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
)
chunk_col = f"videos/{cam}/chunk_index"
file_col = f"videos/{cam}/file_index"
ts_from = f"videos/{cam}/from_timestamp"
ts_to = f"videos/{cam}/to_timestamp"
if chunk_col not in ep_df.columns:
chunk_col = f"{cam}/chunk_index"
file_col = f"{cam}/file_index"
ts_from = f"{cam}/from_timestamp"
ts_to = f"{cam}/to_timestamp"
episodes = []
for _, row in ep_df.iterrows():
ci = int(row[chunk_col])
fi = int(row[file_col])
episodes.append(
{
"episode_index": int(row["episode_index"]),
"chunk_index": ci,
"file_index": fi,
"from_ts": float(row[ts_from]),
"to_ts": float(row[ts_to]),
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
}
)
return cam, episodes, fps
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
picks = []
for _ in range(n):
ep = rng.choice(episodes)
duration = ep["to_ts"] - ep["from_ts"]
if duration <= 0:
continue
t = ep["from_ts"] + rng.random() * duration
picks.append({**ep, "seek_ts": t})
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
return picks
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
"""Download only the video files we need."""
needed = sorted({p["video_rel"] for p in picks})
print(f"[2/3] Downloading {len(needed)} video file(s) …")
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=str(local),
allow_patterns=needed,
)
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
"""Decode a single frame at the given timestamp."""
cap = cv2.VideoCapture(str(video_path))
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
ret, frame = cap.read()
cap.release()
return frame if ret else None
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
"""Resize frames to uniform thumbnails and tile into a grid."""
if not frames:
raise RuntimeError("No frames decoded")
h0, w0 = frames[0].shape[:2]
thumb_h = int(thumb_w * h0 / w0)
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
rows = []
for i in range(0, len(thumbs), cols):
row_thumbs = thumbs[i : i + cols]
while len(row_thumbs) < cols:
row_thumbs.append(np.zeros_like(row_thumbs[0]))
rows.append(np.hstack(row_thumbs))
return np.vstack(rows)
def main() -> None:
rng = random.Random(SEED)
n_frames = GRID_COLS * GRID_ROWS
local = download_metadata(REPO_ID)
cam, episodes, fps = load_video_info(local)
picks = pick_random_frames(episodes, fps, n_frames, rng)
download_video_files(REPO_ID, local, picks)
print(f"[3/3] Decoding {n_frames} frames …")
frames: list[np.ndarray] = []
for p in picks:
vp = local / p["video_rel"]
if not vp.exists():
print(f" SKIP: {p['video_rel']} not found")
continue
frame = extract_frame(vp, p["seek_ts"])
if frame is not None:
frames.append(frame)
print(f" Decoded {len(frames)}/{n_frames} frames")
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
safe_name = REPO_ID.replace("/", "_")
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
if __name__ == "__main__":
main()
@@ -0,0 +1,526 @@
"""
Create MP4 videos with sarm_progress overlay for specified episodes.
Downloads datasets from HuggingFace, extracts episode video + progress data,
and draws the progress line directly on each frame (no panel, no axes).
"""
import json
import subprocess
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
DATASETS = [
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
]
CAMERA_KEY = (
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
)
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
# Progress line spans the full video height
GRAPH_Y_TOP_FRAC = 0.01
GRAPH_Y_BOT_FRAC = 0.99
LINE_THICKNESS = 3
SHADOW_THICKNESS = 6 # white edge thickness
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
def download_episode(repo_id: str, episode: int) -> Path:
"""Download only the files needed for this episode."""
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
# We'll download meta + sarm first, then figure out chunks.
print(f"\n[1/5] Downloading metadata for {repo_id}")
local = Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"],
ignore_patterns=["*.mp4"],
)
)
return local
def load_episode_meta(local: Path, episode: int) -> dict:
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
info = json.loads((local / "meta" / "info.json").read_text())
fps = info["fps"]
features = info["features"]
# Find video keys (keys whose dtype=="video")
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
if not video_keys:
raise RuntimeError("No video keys found in dataset features")
if CAMERA_KEY is not None:
if CAMERA_KEY not in video_keys:
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
first_cam = CAMERA_KEY
else:
first_cam = video_keys[0]
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
# Load all episode-meta parquet files and find our episode
ep_rows = []
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
df = pd.read_parquet(pq)
ep_rows.append(df)
ep_df = pd.concat(ep_rows, ignore_index=True)
row = ep_df[ep_df["episode_index"] == episode]
if row.empty:
raise RuntimeError(f"Episode {episode} not found in episode metadata")
row = row.iloc[0]
# Extract video chunk/file index for first camera
# Try both dot and slash variants of the key
chunk_col = f"videos/{first_cam}/chunk_index"
file_col = f"videos/{first_cam}/file_index"
ts_col = f"videos/{first_cam}/from_timestamp"
to_col = f"videos/{first_cam}/to_timestamp"
# Some datasets use different column naming
if chunk_col not in row.index:
# Try without the 'videos/' prefix
chunk_col = f"{first_cam}/chunk_index"
file_col = f"{first_cam}/file_index"
ts_col = f"{first_cam}/from_timestamp"
to_col = f"{first_cam}/to_timestamp"
if chunk_col not in row.index:
raise RuntimeError(
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
)
chunk_idx = int(row[chunk_col])
file_idx = int(row[file_col])
from_ts = float(row[ts_col])
to_ts = float(row[to_col])
video_template = info.get(
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
)
video_rel = video_template.format(
video_key=first_cam,
chunk_index=chunk_idx,
file_index=file_idx,
)
# Load task name for this episode
# tasks.parquet uses the task string as the row index; task_index column holds the int id
task_name = ""
try:
# Prefer the 'tasks' list directly on the episode row
if "tasks" in row.index and row["tasks"] is not None:
tasks_val = row["tasks"]
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
task_name = str(tasks_val[0])
else:
task_name = str(tasks_val).strip("[]'")
else:
tasks_pq = local / "meta" / "tasks.parquet"
if tasks_pq.exists():
tasks_df = pd.read_parquet(tasks_pq)
# Row index is the task string; task_index column is the int
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
match = tasks_df[tasks_df["task_index"] == task_idx]
if not match.empty:
task_name = str(match.index[0])
print(f" Task name: '{task_name}'")
except Exception as e:
print(f" WARNING: could not load task name: {e}")
return {
"fps": fps,
"first_cam": first_cam,
"video_rel": video_rel,
"chunk_index": chunk_idx,
"file_index": file_idx,
"from_ts": from_ts,
"to_ts": to_ts,
"task_name": task_name,
}
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
"""Download the specific video file if not already present."""
video_path = local / video_rel
if video_path.exists():
print(f" Video already cached: {video_path}")
return video_path
print(f"[2/5] Downloading video file {video_rel}")
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=str(local),
allow_patterns=[video_rel],
)
if not video_path.exists():
raise RuntimeError(f"Video not found after download: {video_path}")
return video_path
def load_progress(local: Path, episode: int) -> np.ndarray | None:
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
pq_path = local / "sarm_progress.parquet"
if not pq_path.exists():
print(" WARNING: sarm_progress.parquet not found, trying data parquet …")
return None
df = pd.read_parquet(pq_path)
print(f" sarm_progress.parquet columns: {list(df.columns)}")
ep_df = df[df["episode_index"] == episode].copy()
if ep_df.empty:
print(f" WARNING: No sarm_progress rows for episode {episode}")
return None
ep_df = ep_df.sort_values("frame_index")
# Prefer dense, fall back to sparse
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
prog_col = "progress_dense"
elif "progress_sparse" in ep_df.columns:
prog_col = "progress_sparse"
else:
# Last resort: any column with 'progress' in the name
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
if not prog_cols:
return None
prog_col = prog_cols[0]
print(f" Using progress column: '{prog_col}'")
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
"""Use ffmpeg to cut the episode segment from the combined video file."""
duration = to_ts - from_ts
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
cmd = [
"ffmpeg",
"-y",
"-ss",
str(from_ts),
"-i",
str(video_path),
"-t",
str(duration),
"-c:v",
"libx264",
"-preset",
"fast",
"-crf",
"18",
"-an",
str(out_path),
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
return out_path
def precompute_pixels(
progress_data: np.ndarray,
n_frames: int,
frame_w: int,
frame_h: int,
) -> np.ndarray:
"""
Map each progress sample to pixel coordinates.
Returns array of shape (N, 2) with (x, y) in pixel space.
x spans full video width; y maps progress [0,1] to graph band.
"""
frame_indices = progress_data[:, 0].astype(float)
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
graph_h = y_bot - y_top
xs = (frame_indices / (n_frames - 1) * (frame_w - 1)).astype(int)
# progress=1 → y_top, progress=0 → y_bot
ys = (y_bot - progress_vals * graph_h).astype(int)
return np.stack([xs, ys], axis=1) # (N, 2)
def progress_color(t: float) -> tuple[int, int, int]:
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
r = int(255 * (1.0 - t))
g = int(255 * t)
return (0, g, r) # BGR
def prerender_fill(
pixels: np.ndarray,
frame_w: int,
frame_h: int,
) -> np.ndarray:
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
poly = np.concatenate(
[
pixels,
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
],
axis=0,
).astype(np.int32)
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
return fill_img
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
"""Blend overlay onto base in-place, but only for x < x_max."""
if x_max <= 0:
return
roi_b = base[:, :x_max]
roi_o = overlay_bgra[:, :x_max]
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
roi_b[:] = np.clip(
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
0,
255,
).astype(np.uint8)
def draw_text_outlined(
frame: np.ndarray,
text: str,
pos: tuple[int, int],
font_scale: float,
thickness: int = 1,
) -> None:
"""Draw text with a dark outline for readability on any background."""
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
def composite_video(
clip_path: Path,
progress_data: np.ndarray,
out_path: Path,
fps: float,
frame_h: int,
frame_w: int,
task_name: str = "",
) -> Path:
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT))
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
# Pre-render fill polygon (line is drawn per-frame with live color)
fill_img = prerender_fill(pixels, frame_w, frame_h)
# 1.0 reference line overlay (full width, drawn once)
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
frame_indices = progress_data[:, 0].astype(int)
progress_vals = progress_data[:, 1].astype(float)
print(f"[4/4] Compositing {n_total} frames …")
cap = cv2.VideoCapture(str(clip_path))
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
fi = 0
while True:
ret, frame = cap.read()
if not ret:
break
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
# 1. reference line (full width, always)
alpha_composite(frame, ref_img, frame_w)
# 2. grey fill under curve up to current x
alpha_composite(frame, fill_img, x_cur)
# 3. progress line — single color that transitions red→green over time
if n_drawn >= 2:
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
line_col = progress_color(t_cur)
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
cv2.polylines(
frame,
[pts],
isClosed=False,
color=(255, 255, 255),
thickness=SHADOW_THICKNESS,
lineType=cv2.LINE_AA,
)
cv2.polylines(
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
)
# 4. score — bottom right
if n_drawn > 0:
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
score_text = f"{score:.2f}"
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
sx = frame_w - tw - 12
sy = frame_h - 12
# coloured score matching current gradient position
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
score_col = progress_color(t_cur)
cv2.putText(
frame,
score_text,
(sx, sy),
cv2.FONT_HERSHEY_SIMPLEX,
SCORE_FONT_SCALE,
(0, 0, 0),
4,
cv2.LINE_AA,
)
cv2.putText(
frame,
score_text,
(sx, sy),
cv2.FONT_HERSHEY_SIMPLEX,
SCORE_FONT_SCALE,
score_col,
2,
cv2.LINE_AA,
)
# 5. task name — top centre
if task_name:
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
tx = max((frame_w - tw) // 2, 4)
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
writer.write(frame)
fi += 1
if fi % 100 == 0:
print(f" Frame {fi}/{n_total}", end="\r")
cap.release()
writer.release()
print()
# Convert to GIF: full resolution, 12fps, 128-color diff palette (<40MB)
gif_path = out_path.with_suffix(".gif")
palette = out_path.parent / "_palette.png"
r1 = subprocess.run( # nosec B607
[
"ffmpeg",
"-y",
"-i",
str(tmp_path),
"-vf",
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
"-update",
"1",
str(palette),
],
capture_output=True,
text=True,
)
if r1.returncode != 0:
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
r2 = subprocess.run( # nosec B607
[
"ffmpeg",
"-y",
"-i",
str(tmp_path),
"-i",
str(palette),
"-filter_complex",
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
str(gif_path),
],
capture_output=True,
text=True,
)
if r2.returncode != 0:
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
tmp_path.unlink(missing_ok=True)
palette.unlink(missing_ok=True)
return gif_path
def process_dataset(repo_id: str, episode: int):
safe_name = repo_id.replace("/", "_")
print(f"\n{'=' * 60}")
print(f"Processing: {repo_id} | episode {episode}")
print(f"{'=' * 60}")
# 1. Download metadata
local = download_episode(repo_id, episode)
print(f" Local cache: {local}")
# 2. Read episode metadata
ep_meta = load_episode_meta(local, episode)
print(f" Episode meta: {ep_meta}")
# 3. Download video file
video_path = download_video(repo_id, local, ep_meta["video_rel"])
# 4. Extract clip
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
# 5. Load progress data
progress_data = load_progress(local, episode)
if progress_data is None:
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
return
n_progress = len(progress_data)
print(f" Progress frames: {n_progress}")
# 6. Get clip dimensions
cap = cv2.VideoCapture(str(clip_path))
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
cap.release()
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
# 7. Composite (draw line directly on frames)
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
final = composite_video(
clip_path,
progress_data,
out_path,
actual_fps,
frame_h,
frame_w,
task_name=ep_meta.get("task_name", ""),
)
clip_path.unlink(missing_ok=True)
print(f"\n✓ Done: {final}")
return final
if __name__ == "__main__":
results = []
for cfg in DATASETS:
try:
out = process_dataset(cfg["repo_id"], cfg["episode"])
if out:
results.append(out)
except Exception as e:
print(f"\nERROR processing {cfg['repo_id']}: {e}")
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print("Output files:")
for r in results:
print(f" {r}")
@@ -0,0 +1,496 @@
"""
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
Downloads joint position data (no videos) from HuggingFace, computes forward
kinematics per episode, clusters trajectories with K-means, and renders
2D projections comparing dataset coverage and multimodality.
"""
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
from sklearn.cluster import KMeans
DATASETS = [
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
]
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
OUTPUT_DIR.mkdir(exist_ok=True)
N_CLUSTERS = 10
WAYPOINTS = 50
SEED = 42
DPI = 180
CLUSTER_COLORS = [
"#e6194b",
"#3cb44b",
"#4363d8",
"#f58231",
"#911eb4",
"#42d4f4",
"#f032e6",
"#bfef45",
"#fabed4",
"#dcbeff",
"#9a6324",
"#fffac8",
"#800000",
"#aaffc3",
"#808000",
"#ffd8b1",
"#000075",
"#a9a9a9",
]
# FK chains extracted from OpenArm bimanual URDF.
# Each entry: (rpy, xyz, revolute_axis_or_None).
LEFT_CHAIN = [
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
((0, 0, 0), (0, 0, 0.1001), None),
((0, 0, 0), (0, 0, 0.08), None),
]
RIGHT_CHAIN = [
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
((0, 0, 0), (0, 0, 0.1001), None),
((0, 0, 0), (0, 0, 0.08), None),
]
# ── FK math ─────────────────────────────────────────────
def _rot_x(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
def _rot_y(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
def _rot_z(a: float) -> np.ndarray:
c, s = np.cos(a), np.sin(a)
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
r, p, y = rpy
mat = np.eye(4)
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
mat[:3, 3] = xyz
return mat
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
"""Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
n = len(angles)
ax = np.asarray(axis, dtype=np.float64)
ax = ax / np.linalg.norm(ax)
x, y, z = ax
c = np.cos(angles)
s = np.sin(angles)
t = 1 - c
rot = np.zeros((n, 4, 4))
rot[:, 0, 0] = t * x * x + c
rot[:, 0, 1] = t * x * y - s * z
rot[:, 0, 2] = t * x * z + s * y
rot[:, 1, 0] = t * x * y + s * z
rot[:, 1, 1] = t * y * y + c
rot[:, 1, 2] = t * y * z - s * x
rot[:, 2, 0] = t * x * z - s * y
rot[:, 2, 1] = t * y * z + s * x
rot[:, 2, 2] = t * z * z + c
rot[:, 3, 3] = 1.0
return rot
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
"""Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
n = joint_angles.shape[0]
tf_batch = np.tile(np.eye(4), (n, 1, 1))
qi = 0
for rpy, xyz, axis in chain:
tf_batch = tf_batch @ _tf(rpy, xyz)
if axis is not None:
rot = _batch_axis_rot(axis, joint_angles[:, qi])
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
qi += 1
return tf_batch[:, :3, 3]
# ── Data loading ────────────────────────────────────────
def _flatten_names(obj: object) -> list[str]:
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
if isinstance(obj, dict):
out: list[str] = []
for v in obj.values():
out.extend(_flatten_names(v))
return out
if isinstance(obj, (list, tuple)):
out = []
for item in obj:
if isinstance(item, (list, tuple, dict)):
out.extend(_flatten_names(item))
else:
out.append(str(item))
return out
return [str(obj)]
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
"""Auto-detect servo ticks / degrees / radians and convert to radians."""
mx = np.max(np.abs(vals))
if mx > 360:
print(f" Unit detection: servo ticks (max={mx:.0f})")
return (vals - 2048) / 2048 * np.pi
if mx > 6.3:
print(f" Unit detection: degrees (max={mx:.1f})")
return np.deg2rad(vals)
print(f" Unit detection: radians (max={mx:.3f})")
return vals.astype(np.float64)
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
"""Try to find left/right joint indices from info.json feature names."""
feat = features.get("observation.state", features.get(state_col, {}))
names = _flatten_names(feat.get("names", []))
left_idx: list[int] = []
right_idx: list[int] = []
if names and len(names) == n_dim:
names_l = [n.lower() for n in names]
print(f" Feature names: {names[:4]}{names[-4:]}")
for j in range(1, 8):
for i, nm in enumerate(names_l):
if f"left_joint_{j}" in nm and i not in left_idx:
left_idx.append(i)
break
for i, nm in enumerate(names_l):
if f"right_joint_{j}" in nm and i not in right_idx:
right_idx.append(i)
break
if len(left_idx) == 7 and len(right_idx) == 7:
print(f" Matched by name: left={left_idx} right={right_idx}")
return left_idx, right_idx
if n_dim >= 16:
print(" Falling back to positional: [0:7]=left, [8:15]=right")
return list(range(7)), list(range(8, 15))
if n_dim >= 14:
print(" Falling back to positional: [0:7]=left, [7:14]=right")
return list(range(7)), list(range(7, 14))
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
def download_data(repo_id: str) -> Path:
print(f" Downloading {repo_id} (parquet only) …")
return Path(
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
allow_patterns=["meta/**", "data/**"],
ignore_patterns=["*.mp4", "videos/**"],
)
)
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
"""Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
f = traj.shape[0]
if f == n_waypoints:
return traj
old_t = np.linspace(0, 1, f)
new_t = np.linspace(0, 1, n_waypoints)
return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
def load_episode_trajectories(local: Path) -> list[dict]:
"""
Load per-episode joint data, compute FK, return list of trajectory dicts.
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
Uses all episodes in the dataset for a fair comparison.
"""
info = json.loads((local / "meta" / "info.json").read_text())
features = info.get("features", {})
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
df = pd.concat(dfs, ignore_index=True)
print(f" Total frames: {len(df):,}")
state_col = next((c for c in df.columns if "observation.state" in c), None)
if state_col is None:
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
first = df[state_col].iloc[0]
if not hasattr(first, "__len__"):
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
state = np.stack(df[state_col].values).astype(np.float64)
n_dim = state.shape[1]
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
ep_col = next((c for c in df.columns if c == "episode_index"), None)
if ep_col is None:
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
episode_ids = df[ep_col].values
unique_eps = np.unique(episode_ids)
print(f" Episodes: {len(unique_eps):,}")
left_raw = state[:, left_idx]
right_raw = state[:, right_idx]
left_all = _detect_and_convert(left_raw)
right_all = _detect_and_convert(right_raw)
print(" Computing FK per episode …")
trajectories = []
for ep_id in unique_eps:
mask = episode_ids == ep_id
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
if len(left_tcp) < 3:
continue
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
print(f" Valid trajectories: {len(trajectories):,}")
return trajectories
# ── Clustering ──────────────────────────────────────────
def cluster_trajectories(
trajectories: list[dict], n_clusters: int, n_waypoints: int
) -> tuple[np.ndarray, np.ndarray]:
"""
K-means on resampled trajectory features.
Combines left+right TCP into a single feature vector per episode.
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
"""
feat_vecs = []
for t in trajectories:
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
feat_matrix = np.array(feat_vecs)
k = min(n_clusters, len(feat_vecs))
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
labels = km.fit_predict(feat_matrix)
centroids_flat = km.cluster_centers_
centroid_trajs = np.zeros((k, n_waypoints, 6))
for ci in range(k):
left_flat = centroids_flat[ci, : n_waypoints * 3]
right_flat = centroids_flat[ci, n_waypoints * 3 :]
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
# Mean per-waypoint distance to centroid (in metres) for each cluster
spread = np.zeros(k)
for ci in range(k):
members = np.where(labels == ci)[0]
if len(members) == 0:
continue
centroid_left = centroid_trajs[ci, :, :3]
centroid_right = centroid_trajs[ci, :, 3:]
dists = []
for mi in members:
t = trajectories[mi]
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
dists.append((d_left + d_right) / 2)
spread[ci] = np.mean(dists)
return labels, centroid_trajs, spread
# ── Visualization ───────────────────────────────────────
PROJ_VIEWS = [
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
("XY (top)", 0, 1, "X (m)", "Y (m)"),
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
]
def render(results: list[dict], out_path: Path) -> None:
"""
2-row × 3-col grid per dataset (3 projections × 2 datasets).
Trajectory lines colored by cluster, centroid trajectories drawn thick.
"""
n_ds = len(results)
n_proj = len(PROJ_VIEWS)
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
if n_ds == 1:
axes = axes[np.newaxis, :]
for row, r in enumerate(results):
trajectories = r["trajectories"]
labels = r["labels"]
centroids = r["centroids"]
k = centroids.shape[0]
cluster_sizes = np.bincount(labels, minlength=k)
size_order = np.argsort(-cluster_sizes)
pcts = cluster_sizes / len(labels) * 100
spread = r["spread"]
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
ax = axes[row, col]
ax.set_facecolor("#0d1117")
for ti, traj in enumerate(trajectories):
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
for tcp_key in ("left_tcp", "right_tcp"):
pts = traj[tcp_key]
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
for ci in range(k):
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
left_c = centroids[ci, :, :3]
right_c = centroids[ci, :, 3:]
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
for c_pts in (left_c, right_c):
ax.plot(
c_pts[:, dim_a],
c_pts[:, dim_b],
color=color,
linewidth=lw,
alpha=0.95,
zorder=10,
)
ax.plot(
c_pts[0, dim_a],
c_pts[0, dim_b],
"o",
color=color,
markersize=4,
zorder=11,
)
ax.plot(
c_pts[-1, dim_a],
c_pts[-1, dim_b],
"s",
color=color,
markersize=4,
zorder=11,
)
ax.set_xlabel(xlabel, color="#888", fontsize=9)
ax.set_ylabel(ylabel, color="#888", fontsize=9)
ax.tick_params(colors="#555", labelsize=7)
for spine in ax.spines.values():
spine.set_color("#333")
ax.set_aspect("equal")
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
if col == 0:
ax.set_title(
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
f"avg spread {mean_spread_cm:.1f}cm)",
color="white",
fontsize=11,
pad=10,
)
else:
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
# Cluster size + spread legend on the rightmost panel
legend_ax = axes[row, -1]
for ci in size_order:
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
spread_cm = spread[ci] * 100
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
legend_ax.plot([], [], color=color, linewidth=3, label=label)
legend_ax.legend(
loc="upper right",
fontsize=7,
frameon=True,
facecolor="#1a1a2e",
edgecolor="#333",
labelcolor="white",
handlelength=1.5,
)
fig.suptitle(
"End-Effector Trajectory Clusters (FK · K-means)",
color="white",
fontsize=16,
y=0.98,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
plt.close()
print(f"\n✓ Saved: {out_path}")
# ── Main ────────────────────────────────────────────────
def main() -> None:
results = []
for ds in DATASETS:
repo_id, label = ds["repo_id"], ds["label"]
print(f"\n{'=' * 60}")
print(f" {label}: {repo_id}")
print(f"{'=' * 60}")
local = download_data(repo_id)
trajectories = load_episode_trajectories(local)
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
for ci in np.argsort(-cluster_sizes):
print(
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
f"spread ±{spread[ci] * 100:.1f}cm"
)
results.append(
{
"label": label,
"trajectories": trajectories,
"labels": labels,
"centroids": centroids,
"spread": spread,
"n_episodes": len(trajectories),
}
)
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
render(results, out)
if __name__ == "__main__":
main()
+1 -1
View File
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import make_default_processors
+1 -1
View File
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
+2 -3
View File
@@ -16,15 +16,13 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import (
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
@@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+3 -2
View File
@@ -15,11 +15,11 @@
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
@@ -38,6 +38,7 @@ from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -1
View File
@@ -18,7 +18,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
transition_to_robot_action,
@@ -27,6 +27,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
+2 -1
View File
@@ -16,7 +16,7 @@
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
transition_to_robot_action,
@@ -31,6 +31,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+2 -1
View File
@@ -22,7 +22,8 @@ from pathlib import Path
import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
DROID_SHARDS = 2048
+2 -2
View File
@@ -26,7 +26,7 @@ from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.utils import init_logging
@@ -155,7 +155,7 @@ class UploadDataset(PipelineStep):
from datasets.utils.tqdm import disable_progress_bars
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.utils.utils import init_logging
init_logging()
+2 -1
View File
@@ -113,8 +113,9 @@ from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.factory import resolve_delta_timestamps
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
+3 -1
View File
@@ -78,10 +78,11 @@ from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@@ -97,6 +98,7 @@ from lerobot.robots import ( # noqa: F401
bi_so_follower,
koch_follower,
so_follower,
unitree_g1,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
+2 -3
View File
@@ -16,15 +16,13 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import (
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
@@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+3 -2
View File
@@ -16,11 +16,11 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
@@ -35,6 +35,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -1
View File
@@ -19,7 +19,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
transition_to_robot_action,
@@ -28,6 +28,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
+2 -1
View File
@@ -17,7 +17,7 @@
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
robot_action_to_transition,
@@ -30,6 +30,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+3 -2
View File
@@ -19,8 +19,9 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
+2 -2
View File
@@ -20,9 +20,9 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
@@ -5,8 +5,9 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
+1 -1
View File
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
@@ -5,8 +5,9 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
+1 -1
View File
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
+1 -1
View File
@@ -6,8 +6,8 @@ from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
+53 -35
View File
@@ -25,11 +25,11 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.4"
version = "0.5.1"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
dynamic = ["readme"]
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
requires-python = ">=3.12"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
@@ -50,7 +50,8 @@ classifiers = [
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Build Tools",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
@@ -61,26 +62,28 @@ dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<5.0.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[cli]>=1.0.0,<2.0.0",
"huggingface-hub>=1.0.0,<2.0.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"setuptools>=71.0.0,<81.0.0",
"cmake>=3.29.0.1,<4.2.0",
"packaging>=24.2,<26.0",
"torch>=2.2.1,<2.11.0",
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"opencv-python-headless>=4.9.0,<4.14.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pynput>=1.7.8,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.24.0,<0.25.0",
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
"draccus==0.10.0", # TODO: Remove ==
"draccus==0.10.0", # TODO: Relax version constraint
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
@@ -95,10 +98,14 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=5.1.0,<6.0.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -112,34 +119,35 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
# "unitree-sdk2==1.0.1",
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"onnx>=1.16.0,<2.0.0",
"meshcat>=0.3.0,<0.4.0",
"matplotlib>=3.9.0,<4.0.0",
"casadi>=3.6.0,<4.0.0",
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
"pyrealsense2-macosx>=2.54,<2.57.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
# Policies
wallx = [
"lerobot[transformers-dep]",
"peft>=0.18.0,<1.0.0",
"scipy==1.15.3", # TODO: Relax version
"torchdiffeq==0.2.5", # TODO: Relax version
"qwen-vl-utils==0.0.11" # TODO: Relax version
"lerobot[peft]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "scipy==1.15.3"] # TODO: Relax scipy version
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"lerobot[peft]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
@@ -148,13 +156,13 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.11,<0.1.0"]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -162,13 +170,19 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
all = [
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
# helps pip's resolver converge by constraining scipy early, before it encounters
# the loose scipy requirements from transitive deps like dm-control and metaworld.
"scipy>=1.14.0,<2.0.0",
"lerobot[dynamixel]",
"lerobot[gamepad]",
"lerobot[hopejr]",
@@ -189,10 +203,11 @@ all = [
"lerobot[aloha]",
"lerobot[pusht]",
"lerobot[phone]",
"lerobot[libero]",
"lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]",
"lerobot[sarm]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
[project.scripts]
@@ -214,11 +229,14 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
[tool.ruff]
target-version = "py310"
target-version = "py312"
line-length = 110
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
@@ -310,7 +328,7 @@ default.extend-ignore-identifiers-re = [
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
[tool.mypy]
python_version = "3.10"
python_version = "3.12"
ignore_missing_imports = true
follow_imports = "skip"
# warn_return_any = true
+170 -271
View File
@@ -1,76 +1,73 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --output-file=requirements-macos.txt requirements.in
#
-e .[all]
# via -[all]
absl-py==2.3.1
absl-py==2.4.0
# via
# dm-control
# dm-env
# dm-tree
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
accelerate==1.13.0
# via
# lerobot
# peft
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.13.3
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-doc==0.0.4
# via
# fastapi
# typer
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
anyio==4.12.1
# via
# httpx
# starlette
# watchfiles
asttokens==3.0.0
asttokens==3.0.1
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.1.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
# via
# lerobot
# qwen-vl-utils
certifi==2026.2.25
# via
# httpcore
# httpx
# requests
# sentry-sdk
cffi==2.0.0
# via pymunk
cfgv==3.4.0
cfgv==3.5.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.5
# via requests
click==8.3.0
click==8.3.1
# via
# typer
# uvicorn
# wandb
cloudpickle==3.1.1
# via
# gymnasium
# libero
cmake==4.1.0
cloudpickle==3.1.2
# via gymnasium
cmake==4.1.3
# via lerobot
cmeel==0.57.3
cmeel==0.59.0
# via
# cmeel-assimp
# cmeel-boost
@@ -108,15 +105,17 @@ cmeel-zlib==1.3.1
# via cmeel-assimp
coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
contourpy==1.3.3
# via
# lerobot
# matplotlib
coverage[toml]==7.13.4
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==4.6.1
# via lerobot
debugpy==1.8.17
debugpy==1.8.20
# via lerobot
decorator==5.2.1
# via ipython
@@ -130,7 +129,7 @@ dill==0.4.0
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.37
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -138,69 +137,55 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.8.4
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via
# lerobot
# libero
einops==0.8.2
# via lerobot
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
etils[epath,epy]==1.14.0
# via mujoco
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
# via stack-data
faker==34.0.2
# via lerobot
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
fastapi==0.135.1
# via
# lerobot
# teleop
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.25.0
# via
# datasets
# diffusers
# huggingface-hub
# python-discovery
# torch
# transformers
# virtualenv
fonttools==4.60.1
fonttools==4.61.1
# via matplotlib
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2026.2.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
gitpython==3.1.46
# via wandb
glfw==2.10.0
# via
@@ -212,7 +197,6 @@ grpcio==1.73.1
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
@@ -223,71 +207,67 @@ gym-hil==0.1.13
# via lerobot
gym-pusht==0.1.6
# via lerobot
gymnasium==1.2.1
gymnasium==1.2.3
# via
# gym-aloha
# gym-hil
# gym-pusht
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
# via
# httpcore
# uvicorn
hebi-py==2.11.0
# via lerobot
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-xet==1.3.2
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httpcore==1.0.9
# via httpx
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
httpx==0.28.1
# via
# datasets
# huggingface-hub
huggingface-hub==1.6.0
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
identify==2.6.17
# via pre-commit
idna==3.11
# via
# anyio
# httpx
# requests
# yarl
imageio[ffmpeg]==2.37.0
imageio[ffmpeg]==2.37.2
# via
# gym-aloha
# gym-hil
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
importlib-metadata==8.7.0
# via imageio
importlib-metadata==8.7.1
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
ipython==8.37.0
ipython==9.11.0
# via meshcat
ipython-pygments-lexers==1.1.1
# via ipython
ischedule==1.2.7
# via placo
jedi==0.19.2
@@ -296,44 +276,24 @@ jinja2==3.1.6
# via torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
lazy-loader==0.5
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
librt==0.8.1
# via mypy
lxml==6.0.2
# via dm-control
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
# via rich
markupsafe==3.0.3
# via
# jinja2
# werkzeug
matplotlib==3.10.7
# via
# lerobot
# libero
# via jinja2
matplotlib==3.10.8
# via lerobot
matplotlib-inline==0.2.1
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
@@ -346,41 +306,35 @@ mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==3.5.0
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# metaworld
# robosuite
multidict==6.7.0
multidict==6.7.1
# via
# aiohttp
# yarl
multiprocess==0.70.16
multiprocess==0.70.18
# via datasets
mypy==1.19.1
# via lerobot
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# mypy
# typing-inspect
networkx==3.6.1
# via
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
nodeenv==1.10.0
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
@@ -389,16 +343,14 @@ numpy==2.2.6
# dm-env
# dm-tree
# gymnasium
# h5py
# hebi-py
# imageio
# labmaze
# libero
# lerobot
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
@@ -406,26 +358,18 @@ numpy==2.2.6
# pyquaternion
# reachy2-sdk
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
opencv-python==4.13.0.92
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -435,97 +379,87 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# qwen-vl-utils
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.3
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.6
# via jedi
peft==0.17.1
pathspec==1.0.4
# via mypy
peft==0.18.1
# via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==12.1.1
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# qwen-vl-utils
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
placo==0.9.16
# via lerobot
platformdirs==4.5.0
platformdirs==4.9.4
# via
# jupyter-core
# python-discovery
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.5.1
# via lerobot
prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
# via ipython
propcache==0.4.1
# via
# aiohttp
# yarl
protobuf==6.31.0
protobuf==6.31.1
# via
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.2.2
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pyarrow==21.0.0
pyarrow==23.0.1
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==3.0
# via cffi
pydantic==2.12.3
pydantic==2.12.5
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic-core==2.41.5
# via pydantic
pygame==2.6.1
# via
@@ -535,33 +469,35 @@ pygame==2.6.1
pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
# pytest
# rich
pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.5.1
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyobjc-core==12.0
pyobjc-core==12.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-cocoa
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-applicationservices==12.0
pyobjc-framework-applicationservices==12.1
# via pynput
pyobjc-framework-cocoa==12.0
pyobjc-framework-cocoa==12.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-coretext==12.0
pyobjc-framework-coretext==12.1
# via pyobjc-framework-applicationservices
pyobjc-framework-quartz==12.0
pyobjc-framework-quartz==12.1
# via
# pynput
# pyobjc-framework-applicationservices
@@ -570,13 +506,13 @@ pyopengl==3.1.10
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.3.2
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2-macosx==2.54.2
pyrealsense2-macosx==2.56.5
# via lerobot
pyserial==3.5
# via
@@ -585,7 +521,6 @@ pyserial==3.5
# lerobot
pytest==8.4.2
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
@@ -596,11 +531,14 @@ pytest-timeout==2.4.0
# via lerobot
python-dateutil==2.9.0.post0
# via
# faker
# matplotlib
# pandas
python-dotenv==1.1.1
python-discovery==1.1.1
# via virtualenv
python-dotenv==1.2.2
# via uvicorn
pytz==2025.2
pytz==2026.1.post1
# via pandas
pyyaml==6.0.3
# via
@@ -609,13 +547,10 @@ pyyaml==6.0.3
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
@@ -625,15 +560,13 @@ pyzmq==27.1.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
qwen-vl-utils==0.0.14
# via lerobot
reachy2-sdk==1.0.15
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2026.2.28
# via
# diffusers
# transformers
@@ -642,184 +575,150 @@ requests==2.32.5
# datasets
# diffusers
# dm-control
# huggingface-hub
# qwen-vl-utils
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.26.2
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
rich==14.3.3
# via typer
safetensors==0.7.0
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
# gym-pusht
# lerobot
scipy==1.15.3
scipy==1.17.1
# via
# dm-control
# lerobot
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
# torchdiffeq
sentry-sdk==2.54.0
# via wandb
shapely==2.1.2
# via gym-pusht
shellingham==1.5.4
# via typer
six==1.17.0
# via
# pynput
# python-dateutil
smmap==5.0.2
smmap==5.0.3
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
starlette==0.52.1
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
teleop==0.1.4
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
tifffile==2025.5.10
termcolor==3.3.0
# via lerobot
tifffile==2026.3.3
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.22.2
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
torch==2.10.0
# via
# accelerate
# lerobot
# peft
# robomimic
# thop
# timm
# torchdiffeq
# torchvision
torchcodec==0.5
torchcodec==0.10.0
# via lerobot
torchvision==0.22.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
torchdiffeq==0.2.5
# via lerobot
torchvision==0.25.0
# via lerobot
tornado==6.5.4
# via meshcat
tqdm==4.67.1
tqdm==4.67.3
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
transformers==5.3.0
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
typer==0.24.1
# via
# huggingface-hub
# transformers
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# faker
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# mypy
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via
# fastapi
# pydantic
tzdata==2025.3
# via pandas
u-msgpack-python==2.8.0
# via meshcat
urllib3==2.5.0
urllib3==2.6.3
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
uvicorn[standard]==0.41.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==21.1.0
# via pre-commit
wandb==0.21.4
# via
# lerobot
# libero
wandb==0.24.2
# via lerobot
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wcwidth==0.6.0
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
websockets==16.0
# via uvicorn
werkzeug==3.1.3
# via tensorboard
wrapt==2.0.0
wrapt==2.1.2
# via dm-tree
xxhash==3.6.0
# via datasets
yarl==1.22.0
yarl==1.23.0
# via aiohttp
zipp==3.23.0
# via
+209 -188
View File
@@ -1,12 +1,12 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
#
-e .[all]
# via -[all]
absl-py==2.3.1
absl-py==2.4.0
# via
# dm-control
# dm-env
@@ -14,30 +14,33 @@ absl-py==2.3.1
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
accelerate==1.13.0
# via
# lerobot
# peft
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.13.3
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-doc==0.0.4
# via
# fastapi
# typer
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
anyio==4.12.1
# via
# httpx
# starlette
# watchfiles
asttokens==3.0.0
asttokens==3.0.1
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
# via
# aiohttp
@@ -47,30 +50,35 @@ attrs==25.4.0
# referencing
# rerun-sdk
av==15.1.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
# via
# lerobot
# qwen-vl-utils
bddl==1.0.1
# via hf-libero
certifi==2026.2.25
# via
# httpcore
# httpx
# requests
# sentry-sdk
cffi==2.0.0
# via pymunk
cfgv==3.4.0
cfgv==3.5.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.5
# via requests
click==8.3.0
click==8.3.1
# via
# typer
# uvicorn
# wandb
cloudpickle==3.1.1
cloudpickle==3.1.2
# via
# gymnasium
# libero
cmake==4.1.0
# hf-libero
cmake==4.1.3
# via lerobot
cmeel==0.57.3
cmeel==0.59.0
# via
# cmeel-assimp
# cmeel-boost
@@ -108,20 +116,24 @@ cmeel-zlib==1.3.1
# via cmeel-assimp
coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
contourpy==1.3.3
# via
# lerobot
# matplotlib
coverage[toml]==7.13.4
# via pytest-cov
cuda-bindings==12.9.4
# via torch
cuda-pathfinder==1.4.1
# via cuda-bindings
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==4.6.1
# via lerobot
debugpy==1.8.17
debugpy==1.8.20
# via lerobot
decorator==5.2.1
# via ipython
decord==0.6.0
# via lerobot
deepdiff==8.6.1
# via lerobot
diffusers==0.35.2
@@ -132,7 +144,7 @@ dill==0.4.0
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.37
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -140,7 +152,6 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
@@ -148,66 +159,60 @@ draccus==0.10.0
dynamixel-sdk==3.8.4
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
# via hf-libero
egl-probe==1.0.2
# via robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
einops==0.8.2
# via
# flash-attn
# hf-libero
# lerobot
# libero
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
etils[epath,epy]==1.14.0
# via mujoco
evdev==1.9.2
evdev==1.9.3
# via pynput
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
# via stack-data
faker==34.0.2
# via lerobot
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastapi==0.135.1
# via
# lerobot
# teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.25.0
# via
# datasets
# diffusers
# huggingface-hub
# python-discovery
# torch
# transformers
# virtualenv
flash-attn==2.8.3
# via lerobot
fonttools==4.60.1
fonttools==4.61.1
# via matplotlib
frozenlist==1.8.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2026.2.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
# via hf-libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
gitpython==3.1.46
# via wandb
glfw==2.10.0
# via
@@ -230,50 +235,60 @@ gym-hil==0.1.13
# via lerobot
gym-pusht==0.1.6
# via lerobot
gymnasium==1.2.1
gymnasium==1.2.3
# via
# gym-aloha
# gym-hil
# gym-pusht
# hf-libero
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via
# httpcore
# uvicorn
h5py==3.16.0
# via robomimic
hebi-py==2.11.0
# via lerobot
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-egl-probe==1.0.2
# via hf-libero
hf-libero==0.1.3
# via lerobot
hf-xet==1.3.2
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httpcore==1.0.9
# via httpx
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
httpx==0.28.1
# via
# datasets
# huggingface-hub
huggingface-hub==1.6.0
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
# via hf-libero
identify==2.6.17
# via pre-commit
idna==3.11
# via
# anyio
# httpx
# requests
# yarl
imageio[ffmpeg]==2.37.0
imageio[ffmpeg]==2.37.2
# via
# gym-aloha
# gym-hil
@@ -285,16 +300,14 @@ imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
importlib-metadata==8.7.0
importlib-metadata==8.7.1
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
ipython==8.37.0
ipython==9.11.0
# via meshcat
ipython-pygments-lexers==1.1.1
# via ipython
ischedule==1.2.7
# via placo
jedi==0.19.2
@@ -303,40 +316,41 @@ jinja2==3.1.6
# via torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
jsonschema==4.26.0
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
jupytext==1.19.1
# via bddl
kiwisolver==1.4.9
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
lazy-loader==0.5
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
librt==0.8.1
# via mypy
llvmlite==0.46.0
# via numba
lxml==6.0.2
# via dm-control
markdown==3.9
markdown==3.10.2
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
# rich
markupsafe==3.0.3
# via
# jinja2
# werkzeug
matplotlib==3.10.7
matplotlib==3.10.8
# via
# hf-libero
# lerobot
# libero
matplotlib-inline==0.2.1
# via ipython
mdit-py-plugins==0.5.0
@@ -353,36 +367,38 @@ mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==3.5.0
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# hf-libero
# metaworld
# robosuite
multidict==6.7.0
multidict==6.7.1
# via
# aiohttp
# yarl
multiprocess==0.70.16
multiprocess==0.70.18
# via datasets
mypy==1.19.1
# via lerobot
mypy-extensions==1.1.0
# via typing-inspect
# via
# mypy
# typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
networkx==3.6.1
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
nodeenv==1.10.0
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
numba==0.64.0
# via robosuite
numpy==2.2.6
# via
@@ -391,7 +407,6 @@ numpy==2.2.6
# cmeel-boost
# contourpy
# datasets
# decord
# diffusers
# dm-control
# dm-env
@@ -399,9 +414,10 @@ numpy==2.2.6
# gymnasium
# h5py
# hebi-py
# hf-libero
# imageio
# labmaze
# libero
# lerobot
# matplotlib
# meshcat
# metaworld
@@ -426,49 +442,51 @@ numpy==2.2.6
# torchvision
# transformers
# transforms3d
nvidia-cublas-cu12==12.6.4.1
nvidia-cublas-cu12==12.8.4.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-cupti-cu12==12.8.90
# via torch
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-nvrtc-cu12==12.8.93
# via torch
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.8.90
# via torch
nvidia-cudnn-cu12==9.5.1.17
nvidia-cudnn-cu12==9.10.2.21
# via torch
nvidia-cufft-cu12==11.3.0.4
nvidia-cufft-cu12==11.3.3.83
# via torch
nvidia-cufile-cu12==1.11.1.6
nvidia-cufile-cu12==1.13.1.3
# via torch
nvidia-curand-cu12==10.3.7.77
nvidia-curand-cu12==10.3.9.90
# via torch
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusolver-cu12==11.7.3.90
# via torch
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparse-cu12==12.5.8.93
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.3
nvidia-cusparselt-cu12==0.7.1
# via torch
nvidia-nccl-cu12==2.26.2
nvidia-nccl-cu12==2.27.5
# via torch
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvjitlink-cu12==12.8.93
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.6.77
nvidia-nvshmem-cu12==3.4.5
# via torch
nvidia-nvtx-cu12==12.8.90
# via torch
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
opencv-python==4.13.0.92
# via
# gym-pusht
# libero
# hf-libero
# reachy2-sdk
# robosuite
opencv-python-headless==4.12.0.88
@@ -487,6 +505,7 @@ packaging==25.0
# matplotlib
# peft
# pytest
# qwen-vl-utils
# reachy2-sdk
# scikit-image
# tensorboard
@@ -497,21 +516,21 @@ pandas==2.3.3
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.6
# via jedi
peft==0.17.1
pathspec==1.0.4
# via mypy
peft==0.18.1
# via lerobot
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==12.1.1
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# qwen-vl-utils
# rerun-sdk
# robosuite
# scikit-image
@@ -519,28 +538,27 @@ pillow==12.0.0
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
placo==0.9.16
# via lerobot
platformdirs==4.5.0
platformdirs==4.9.4
# via
# jupyter-core
# python-discovery
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.5.1
# via lerobot
prompt-toolkit==3.0.52
# via
# inquirerpy
# ipython
# via ipython
propcache==0.4.1
# via
# aiohttp
# yarl
protobuf==6.31.0
protobuf==6.31.1
# via
# dm-control
# grpcio-tools
@@ -550,7 +568,7 @@ protobuf==6.31.0
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.2.2
# via
# accelerate
# imageio
@@ -560,17 +578,17 @@ ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pyarrow==21.0.0
pyarrow==23.0.1
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==3.0
# via cffi
pydantic==2.12.3
pydantic==2.12.5
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic-core==2.41.5
# via pydantic
pygame==2.6.1
# via
@@ -580,12 +598,14 @@ pygame==2.6.1
pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
# pytest
# rich
pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.5.1
# via meshcat
pynput==1.8.1
# via
@@ -595,7 +615,7 @@ pyopengl==3.1.10
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.3.2
# via
# dm-control
# matplotlib
@@ -621,13 +641,16 @@ pytest-timeout==2.4.0
# via lerobot
python-dateutil==2.9.0.post0
# via
# faker
# matplotlib
# pandas
python-dotenv==1.1.1
python-discovery==1.1.1
# via virtualenv
python-dotenv==1.2.2
# via uvicorn
python-xlib==0.33
# via pynput
pytz==2025.2
pytz==2026.1.post1
# via pandas
pyyaml==6.0.3
# via
@@ -642,7 +665,6 @@ pyyaml==6.0.3
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
@@ -652,7 +674,9 @@ pyzmq==27.1.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
qwen-vl-utils==0.0.14
# via lerobot
reachy2-sdk==1.0.15
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
@@ -660,7 +684,7 @@ referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2026.2.28
# via
# diffusers
# transformers
@@ -669,60 +693,62 @@ requests==2.32.5
# datasets
# diffusers
# dm-control
# huggingface-hub
# qwen-vl-utils
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.26.2
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
rich==14.3.3
# via typer
robomimic==0.2.0
# via libero
# via hf-libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via hf-libero
rpds-py==0.30.0
# via
# jsonschema
# referencing
safetensors==0.6.2
safetensors==0.7.0
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
# gym-pusht
# lerobot
scipy==1.15.3
scipy==1.17.1
# via
# dm-control
# lerobot
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
# torchdiffeq
sentry-sdk==2.54.0
# via wandb
shapely==2.1.2
# via gym-pusht
shellingham==1.5.4
# via typer
six==1.17.0
# via
# pynput
# python-dateutil
# python-xlib
smmap==5.0.2
smmap==5.0.3
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
starlette==0.52.1
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
teleop==0.1.4
# via lerobot
tensorboard==2.20.0
# via robomimic
@@ -730,46 +756,38 @@ tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
termcolor==3.3.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
tifffile==2025.5.10
# via hf-libero
tifffile==2026.3.3
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.22.2
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
torch==2.10.0
# via
# accelerate
# flash-attn
# lerobot
# peft
# robomimic
# thop
# timm
# torchdiffeq
# torchvision
torchcodec==0.5
torchcodec==0.10.0
# via lerobot
torchvision==0.22.1
torchdiffeq==0.2.5
# via lerobot
torchvision==0.25.0
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
tornado==6.5.4
# via meshcat
tqdm==4.67.1
tqdm==4.67.3
# via
# datasets
# dm-control
@@ -783,26 +801,29 @@ traitlets==5.14.3
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
transformers==5.3.0
# via
# hf-libero
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
triton==3.3.1
triton==3.6.0
# via torch
typer==0.24.1
# via
# huggingface-hub
# transformers
typing-extensions==4.15.0
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# faker
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# mypy
# pydantic
# pydantic-core
# referencing
@@ -811,46 +832,46 @@ typing-extensions==4.15.0
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
# via pydantic
tzdata==2025.2
# via
# fastapi
# pydantic
tzdata==2025.3
# via pandas
u-msgpack-python==2.8.0
# via meshcat
urllib3==2.5.0
urllib3==2.6.3
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
uvicorn[standard]==0.41.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==21.1.0
# via pre-commit
wandb==0.21.4
wandb==0.24.2
# via
# hf-libero
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wcwidth==0.6.0
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
websockets==16.0
# via uvicorn
werkzeug==3.1.3
werkzeug==3.1.6
# via tensorboard
wrapt==2.0.0
wrapt==2.1.2
# via dm-tree
xxhash==3.6.0
# via datasets
yarl==1.22.0
yarl==1.23.0
# via aiohttp
zipp==3.23.0
# via
+4 -4
View File
@@ -1,9 +1,9 @@
# requirements.in
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.3.1 25D2128 arm64).
# Darwin MacBook-Pro.local 25.3.0 Darwin Kernel Version 25.3.0: Wed Jan 28 20:54:55 PST 2026; root:xnu-12377.91.3~2/RELEASE_ARM64_T8132 arm64
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.4 LTS x86_64).
# Linux lerobot-linux 6.17.0-14-generic #14~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Jan 15 15:52:10 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
-e .[all]
+1 -1
View File
@@ -23,7 +23,7 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ( # noqa: F401
+2 -4
View File
@@ -39,15 +39,13 @@ import grpc
import torch
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
)
from lerobot.processor import PolicyProcessorPipeline
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import receive_bytes_in_chunks
from lerobot.types import PolicyAction
from .configs import PolicyServerConfig
from .constants import SUPPORTED_POLICIES
+5 -3
View File
@@ -63,9 +63,9 @@ from lerobot.transport import (
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from lerobot.utils.import_utils import register_third_party_plugins
from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
@@ -485,8 +485,9 @@ class RobotClient:
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))
if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
# TODO: Assert if checking robot support is still needed with the plugin system
# if cfg.robot.type not in SUPPORTED_ROBOTS:
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg)
@@ -512,4 +513,5 @@ def async_client(cfg: RobotClientConfig):
if __name__ == "__main__":
register_third_party_plugins()
async_client() # run the client
+1 -1
View File
@@ -181,7 +181,7 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
+72 -4
View File
@@ -23,6 +23,7 @@ import base64
import contextlib
import json
import logging
import threading
import time
from collections import deque
@@ -42,10 +43,57 @@ def encode_image(image: np.ndarray, quality: int = 80) -> str:
return base64.b64encode(buffer).decode("utf-8")
class CameraCaptureThread:
"""Background thread that continuously captures and encodes frames from a camera."""
def __init__(self, camera: OpenCVCamera, name: str):
self.camera = camera
self.name = name
self.latest_encoded: str | None = None # Pre-encoded JPEG as base64
self.latest_timestamp: float = 0.0
self.frame_lock = threading.Lock()
self.running = False
self.thread: threading.Thread | None = None
def start(self):
"""Start the capture thread."""
self.running = True
self.thread = threading.Thread(target=self._capture_loop, daemon=True)
self.thread.start()
def stop(self):
"""Stop the capture thread."""
self.running = False
if self.thread:
self.thread.join(timeout=1.0)
def _capture_loop(self):
"""Continuously capture and encode frames at the camera's native rate."""
while self.running:
try:
frame = self.camera.read() # Blocks at camera's native rate
timestamp = time.time()
# Encode immediately in capture thread (this is the slow part)
encoded = encode_image(frame)
with self.frame_lock:
self.latest_encoded = encoded
self.latest_timestamp = timestamp
except Exception as e:
logger.warning(f"Camera {self.name} capture error: {e}")
time.sleep(0.01)
def get_latest(self) -> tuple[str | None, float]:
"""Get the latest encoded frame and its timestamp."""
with self.frame_lock:
return self.latest_encoded, self.latest_timestamp
class ImageServer:
def __init__(self, config: dict, port: int = 5555):
# fps controls the publish loop rate (how often frames are sent over ZMQ), not the camera capture rate
self.fps = config.get("fps", 30)
self.cameras: dict[str, OpenCVCamera] = {}
self.capture_threads: dict[str, CameraCaptureThread] = {}
for name, cfg in config.get("cameras", {}).items():
shape = cfg.get("shape", [480, 640])
@@ -61,6 +109,10 @@ class ImageServer:
self.cameras[name] = camera
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
# Create capture thread for this camera
capture_thread = CameraCaptureThread(camera, name)
self.capture_threads[name] = capture_thread
# ZMQ PUB socket
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
@@ -73,6 +125,18 @@ class ImageServer:
def run(self):
frame_count = 0
frame_times = deque(maxlen=60)
last_published_ts: dict[str, float] = {}
# Start all capture threads
for capture_thread in self.capture_threads.values():
capture_thread.start()
# Wait for first frames to be captured and encoded
logger.info("Waiting for cameras to start capturing...")
for name, capture_thread in self.capture_threads.items():
while capture_thread.get_latest()[0] is None:
time.sleep(0.01)
logger.info(f"Camera {name} ready (capture + encode in background)")
try:
while True:
@@ -80,10 +144,12 @@ class ImageServer:
# Build message
message = {"timestamps": {}, "images": {}}
for name, cam in self.cameras.items():
frame = cam.read() # Returns RGB
message["timestamps"][name] = time.time()
message["images"][name] = encode_image(frame)
for name, capture_thread in self.capture_threads.items():
encoded, timestamp = capture_thread.get_latest()
if encoded is not None and timestamp > last_published_ts.get(name, 0.0):
message["timestamps"][name] = timestamp
message["images"][name] = encoded
last_published_ts[name] = timestamp
# Send as JSON string (suppress if buffer full)
with contextlib.suppress(zmq.Again):
@@ -102,6 +168,8 @@ class ImageServer:
except KeyboardInterrupt:
pass
finally:
for capture_thread in self.capture_threads.values():
capture_thread.stop()
for cam in self.cameras.values():
cam.disconnect()
self.socket.close()
+12 -1
View File
@@ -27,7 +27,7 @@ class DatasetConfig:
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
@@ -36,6 +36,16 @@ class DatasetConfig:
video_backend: str = field(default_factory=get_safe_default_codec)
streaming: bool = False
def __post_init__(self) -> None:
if self.episodes is not None:
if any(ep < 0 for ep in self.episodes):
raise ValueError(
f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}"
)
if len(self.episodes) != len(set(self.episodes)):
duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1})
raise ValueError(f"Episode indices contain duplicates: {duplicates}")
@dataclass
class WandBConfig:
@@ -47,6 +57,7 @@ class WandBConfig:
notes: str | None = None
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
add_tags: bool = True # If True, save configuration as tags in the WandB run.
@dataclass
+1 -1
View File
@@ -30,8 +30,8 @@ from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
+3
View File
@@ -50,6 +50,9 @@ class TrainPipelineConfig(HubMixin):
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
# Set to True to use deterministic cuDNN algorithms for reproducibility.
# This disables cudnn.benchmark and may reduce training speed by ~10-20 percent.
cudnn_deterministic: bool = False
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
@@ -746,7 +746,8 @@ def save_annotations_to_dataset(
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
):
"""Save annotations to LeRobot dataset parquet format."""
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
from lerobot.datasets.io_utils import load_episodes
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
@@ -840,7 +841,7 @@ def generate_auto_sparse_annotations(
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
"""Load annotations from LeRobot dataset parquet files."""
from lerobot.datasets.utils import load_episodes
from lerobot.datasets.io_utils import load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
+13 -9
View File
@@ -24,7 +24,16 @@ import pandas as pd
import tqdm
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -32,14 +41,7 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
get_file_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
@@ -289,7 +291,9 @@ def aggregate_datasets(
logging.info("Find all tasks")
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
dst_meta.tasks = pd.DataFrame(
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
)
meta_idx = {"chunk": 0, "file": 0}
data_idx = {"chunk": 0, "file": 0}
@@ -1,56 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import packaging.version
V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v3.0 which is not backward compatible with v2.1.
Please, update your dataset to the new format using this command:
```
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
```
If you already have a converted version uploaded to the hub, then this error might be because of
an older version in your local cache. Consider deleting the cached version and retrying.
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
"""
class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
if version.major == 2 and version.minor == 1:
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
else:
raise NotImplementedError(
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
)
super().__init__(message)
class ForwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
+7
View File
@@ -7,6 +7,13 @@
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
## Dataset Description
{{ dataset_description | default("", true) }}
+1 -1
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
from lerobot.datasets.io_utils import load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
+517
View File
@@ -0,0 +1,517 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import numpy as np
import packaging.version
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
from lerobot.datasets.io_utils import (
get_file_size_in_mb,
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_json,
write_stats,
write_tasks,
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
INFO_PATH,
check_version_compatibility,
flatten_dict,
get_safe_version,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
metadata_buffer_size: int = 10,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.writer = None
self.latest_episode = None
self.metadata_buffer: list[dict] = []
self.metadata_buffer_size = metadata_buffer_size
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
def _flush_metadata_buffer(self) -> None:
"""Write all buffered episode metadata to parquet file."""
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
return
combined_dict = {}
for episode_dict in self.metadata_buffer:
for key, value in episode_dict.items():
if key not in combined_dict:
combined_dict[key] = []
# Extract value and serialize numpy arrays
# because PyArrow's from_pydict function doesn't support numpy arrays
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
first_ep = self.metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
file_idx = first_ep["meta/episodes/file_index"][0]
table = pa.Table.from_pydict(combined_dict)
if not self.writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path.parent.mkdir(parents=True, exist_ok=True)
self.writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self.writer.write_table(table)
self.latest_episode = self.metadata_buffer[-1]
self.metadata_buffer.clear()
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
self._flush_metadata_buffer()
writer = getattr(self, "writer", None)
if writer is not None:
writer.close()
self.writer = None
def __del__(self):
"""
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
"""
self._close_writer()
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
@property
def url_root(self) -> str:
return f"hf://datasets/{self.repo_id}"
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep["data/chunk_index"]
file_idx = ep["data/file_index"]
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
file_idx = ep[f"videos/{vid_key}/file_index"]
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["video_path"]
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
return self.info["robot_type"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def features(self) -> dict[str, dict]:
"""All features contained in the dataset."""
return self.info["features"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
return {key: ft["names"] for key, ft in self.features.items()}
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def chunks_size(self) -> int:
"""Max number of files per chunk."""
return self.info["chunks_size"]
@property
def data_files_size_in_mb(self) -> int:
"""Max size of data file in mega bytes."""
return self.info["data_files_size_in_mb"]
@property
def video_files_size_in_mb(self) -> int:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise return None.
"""
if task in self.tasks.index:
return int(self.tasks.loc[task].task_index)
else:
return None
def save_episode_tasks(self, tasks: list[str]):
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
self.tasks.loc[task] = task_idx
if len(new_tasks) > 0:
# Update on disk
write_tasks(self.tasks, self.root)
def _save_episode_metadata(self, episode_dict: dict) -> None:
"""Buffer episode metadata and write to parquet in batches for efficiency.
This function accumulates episode metadata in a buffer and flushes it when the buffer
reaches the configured size. This reduces I/O overhead by writing multiple episodes
at once instead of one row at a time.
Notes: We both need to update parquet files and HF dataset:
- `pandas` loads parquet file in RAM
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
or loads directly from pyarrow cache.
"""
# Convert to list format for each value
episode_dict = {key: [value] for key, value in episode_dict.items()}
num_frames = episode_dict["length"][0]
if self.latest_episode is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self.episodes is not None and len(self.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
file_idx = self.episodes[-1]["meta/episodes/file_index"]
latest_num_frames = self.episodes[-1]["dataset_to_index"]
episode_dict["dataset_from_index"] = [latest_num_frames]
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
else:
episode_dict["dataset_from_index"] = [0]
episode_dict["dataset_to_index"] = [num_frames]
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
else:
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
file_idx = self.latest_episode["meta/episodes/file_index"][0]
latest_path = (
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
if self.writer is None
else self.writer.where
)
if Path(latest_path).exists():
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
latest_num_frames = self.latest_episode["episode_index"][0]
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
# Size limit is reached, flush buffer and prepare new parquet file
self._flush_metadata_buffer()
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
self._close_writer()
# Update the existing pandas dataframe with new row
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
# Add to buffer
self.metadata_buffer.append(episode_dict)
self.latest_episode = episode_dict
if len(self.metadata_buffer) >= self.metadata_buffer_size:
self._flush_metadata_buffer()
def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
"length": episode_length,
}
episode_dict.update(episode_metadata)
episode_dict.update(flatten_dict({"stats": episode_stats}))
self._save_episode_metadata(episode_dict)
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["total_tasks"] = len(self.tasks)
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.
This allows users to customize storage organization without modifying the constructor.
These settings control how episodes are chunked and how large files can grow before
creating new ones.
Args:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
self.info["chunks_size"] = chunks_size
if data_files_size_in_mb is not None:
if data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
self.info["data_files_size_in_mb"] = data_files_size_in_mb
if video_files_size_in_mb is not None:
if video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
def get_chunk_settings(self) -> dict[str, int]:
"""Get current chunk and file size settings.
Returns:
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
}
def __repr__(self):
feature_keys = list(self.features)
return (
f"{self.__class__.__name__}({{\n"
f" Repository ID: '{self.repo_id}',\n"
f" Total episodes: '{self.total_episodes}',\n"
f" Total frames: '{self.total_frames}',\n"
f" Features: '{feature_keys}',\n"
"})',\n"
)
@classmethod
def create(
cls,
repo_id: str,
fps: int,
features: dict,
robot_type: str | None = None,
root: str | Path | None = None,
use_videos: bool = True,
metadata_buffer_size: int = 10,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
CODEBASE_VERSION,
fps,
features,
use_videos,
robot_type,
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError(
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
"Either remove video features from the features dict, or set 'use_videos=True'."
)
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
obj.latest_episode = None
obj.metadata_buffer = []
obj.metadata_buffer_size = metadata_buffer_size
return obj
+52 -40
View File
@@ -38,19 +38,22 @@ from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.io_utils import (
get_parquet_file_size_in_mb,
load_episodes,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
@@ -89,8 +92,8 @@ def delete_episodes(
Args:
dataset: The source LeRobotDataset.
episode_indices: List of episode indices to delete.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
"""
if not episode_indices:
raise ValueError("No episodes to delete")
@@ -152,7 +155,7 @@ def split_dataset(
dataset: The source LeRobotDataset to split.
splits: Either a dict mapping split names to episode indices, or a dict mapping
split names to fractions (must sum to <= 1.0).
output_dir: Base directory for output datasets. If None, uses default location.
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
Examples:
Split by specific episodes
@@ -243,8 +246,8 @@ def merge_datasets(
Args:
datasets: List of LeRobotDatasets to merge.
output_repo_id: Repository ID for the merged dataset.
output_dir: Directory to save the merged dataset. If None, uses default location.
output_repo_id: Merged dataset identifier.
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
"""
if not datasets:
raise ValueError("No datasets to merge")
@@ -288,8 +291,8 @@ def modify_features(
dataset: The source LeRobotDataset.
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
remove_features: Optional feature name(s) to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with features modified.
@@ -390,8 +393,8 @@ def add_features(
Args:
dataset: The source LeRobotDataset.
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with all features added.
@@ -427,8 +430,8 @@ def remove_feature(
Args:
dataset: The source LeRobotDataset.
feature_names: Name(s) of features to remove. Can be a single string or list.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
Returns:
New dataset with features removed.
@@ -567,20 +570,22 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[float, float]],
episodes_to_keep: list[tuple[int, int]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
This function decodes frames from specified time ranges and re-encodes them with
This function decodes frames from specified frame ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.
Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -622,9 +627,10 @@ def _keep_episodes_from_video_with_av(
# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
time_ranges = sorted(episodes_to_keep)
frame_ranges = sorted(episodes_to_keep)
# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0
@@ -634,21 +640,20 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue
# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
# Check if frame is in any of our desired time ranges.
# Check if frame is in any of our desired frame ranges.
# Skip ranges that have already passed.
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
range_idx += 1
# If we've passed all ranges, stop processing.
if range_idx >= len(time_ranges):
if range_idx >= len(frame_ranges):
break
# Check if frame is in current range.
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
start_frame = frame_ranges[range_idx][0]
if src_frame_count < start_frame:
src_frame_count += 1
continue
# Frame is in range - create a new frame with reset timestamps.
@@ -661,6 +666,7 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)
src_frame_count += 1
frame_count += 1
# Flush encoder.
@@ -749,15 +755,17 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of time ranges to keep, in sorted order.
# Build list of frame ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[float, float]] = []
episodes_to_keep_ranges: list[tuple[int, int]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))
# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
@@ -910,7 +918,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
"""
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import embed_images
hf_features = get_hf_features_from_features(meta.features)
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
@@ -1470,7 +1479,9 @@ def modify_tasks(
# Collect all unique tasks and create new task mapping
unique_tasks = sorted(set(episode_to_task.values()))
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
new_task_df = pd.DataFrame(
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
)
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
logging.info(f"Modifying tasks in {dataset.repo_id}")
@@ -1524,7 +1535,7 @@ def modify_tasks(
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
output_dir: Path | None = None,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
@@ -1543,8 +1554,8 @@ def convert_image_to_video_dataset(
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
@@ -1595,6 +1606,7 @@ def convert_image_to_video_dataset(
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
+3 -5
View File
@@ -20,11 +20,9 @@ import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
+552
View File
@@ -0,0 +1,552 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pformat
from typing import Any
import datasets
import numpy as np
from PIL import Image as PILImage
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string
def get_hf_features_from_features(features: dict) -> datasets.Features:
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
Args:
features (dict): A LeRobot-style feature dictionary.
Returns:
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
Raises:
ValueError: If a feature has an unsupported shape.
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
elif len(ft["shape"]) == 2:
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 3:
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 4:
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 5:
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
else:
raise ValueError(f"Corresponding feature is not valid: {ft}")
return datasets.Features(hf_features)
def _validate_feature_names(features: dict[str, dict]) -> None:
"""Validate that feature names do not contain invalid characters.
Args:
features (dict): The LeRobot features dictionary.
Raises:
ValueError: If any feature name contains '/'.
"""
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
This function takes a dictionary describing hardware outputs (like joint states
or camera image shapes) and formats it into the standard LeRobot feature
specification.
Args:
hw_features (dict): Dictionary mapping feature names to their type (float for
joints) or shape (tuple for images).
prefix (str): The prefix to add to the feature keys (e.g., "observation"
or "action").
use_video (bool): If True, image features are marked as "video", otherwise "image".
Returns:
dict: A LeRobot features dictionary.
"""
features = {}
joint_fts = {
key: ftype
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION:
features[prefix] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
if joint_fts and prefix == OBS_STR:
features[f"{prefix}.state"] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
for key, shape in cam_fts.items():
features[f"{prefix}.images.{key}"] = {
"dtype": "video" if use_video else "image",
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features)
return features
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
"""Construct a single data frame from raw values based on dataset features.
A "frame" is a dictionary containing all the data for a single timestep,
formatted as numpy arrays according to the feature specification.
Args:
ds_features (dict): The LeRobot dataset features dictionary.
values (dict): A dictionary of raw values from the hardware/environment.
prefix (str): The prefix to filter features by (e.g., "observation"
or "action").
Returns:
dict: A dictionary representing a single frame of data.
"""
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
continue
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert dataset features to policy features.
This function transforms the dataset's feature specification into a format
that a policy can use, classifying features by type (e.g., visual, state,
action) and ensuring correct shapes (e.g., channel-first for images).
Args:
features (dict): The LeRobot dataset features dictionary.
Returns:
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
Raises:
ValueError: If an image feature does not have a 3D shape.
"""
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):
type = FeatureType.STATE
elif key.startswith(ACTION):
type = FeatureType.ACTION
else:
continue
policy_features[key] = PolicyFeature(
type=type,
shape=shape,
)
return policy_features
def combine_feature_dicts(*dicts: dict) -> dict:
"""Merge LeRobot grouped feature dicts.
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
Args:
*dicts: A variable number of LeRobot feature dictionaries to merge.
Returns:
dict: A single merged feature dictionary.
Raises:
ValueError: If there's a dtype mismatch for a feature being merged.
"""
out: dict = {}
for d in dicts:
for key, value in d.items():
if not isinstance(value, dict):
out[key] = value
continue
dtype = value.get("dtype")
shape = value.get("shape")
is_vector = (
dtype not in ("image", "video", "string")
and isinstance(shape, tuple)
and len(shape) == 1
and "names" in value
)
if is_vector:
# Initialize or retrieve the accumulating dict for this feature key
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
# Ensure consistent data types across merged entries
if "dtype" in target and dtype != target["dtype"]:
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
# Merge feature names: append only new ones to preserve order without duplicates
seen = set(target["names"])
for n in value["names"]:
if n not in seen:
target["names"].append(n)
seen.add(n)
# Recompute the shape to reflect the updated number of features
target["shape"] = (len(target["names"]),)
else:
# For images/videos and non-1D entries: override with the latest definition
out[key] = value
return out
def create_empty_dataset_info(
codebase_version: str,
fps: int,
features: dict,
use_videos: bool,
robot_type: str | None = None,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
Args:
codebase_version (str): The version of the LeRobot codebase.
fps (int): The frames per second of the data.
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
Returns:
dict: A dictionary with the initial dataset metadata.
"""
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
This ensures that adding these delta timestamps to any existing timestamp in
the dataset will result in a value that aligns with the dataset's frame rate.
Args:
delta_timestamps (dict): A dictionary where values are lists of time
deltas in seconds.
fps (int): The frames per second of the dataset.
tolerance_s (float): The allowed tolerance in seconds.
raise_value_error (bool): If True, raises an error on failure.
Returns:
bool: True if all deltas are valid, False otherwise.
Raises:
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
]
if len(outside_tolerance) > 0:
if raise_value_error:
raise ValueError(
f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
)
return False
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
"""Convert delta timestamps in seconds to delta indices in frames.
Args:
delta_timestamps (dict): A dictionary of time deltas in seconds.
fps (int): The frames per second of the dataset.
Returns:
dict: A dictionary of frame delta indices.
"""
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
return delta_indices
def validate_frame(frame: dict, features: dict) -> None:
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
# task is a special required field that's not part of regular features
if "task" not in actual_features:
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
# Remove task from actual_features for regular feature validation
actual_features_for_validation = actual_features - {"task"}
error_message = validate_features_presence(actual_features_for_validation, expected_features)
common_features = actual_features_for_validation & expected_features
for name in common_features:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
"""Check for missing or extra features in a frame.
Args:
actual_features (set[str]): The set of feature names present in the frame.
expected_features (set[str]): The set of feature names expected in the frame.
Returns:
str: An error message string if there's a mismatch, otherwise an empty string.
"""
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"
if missing_features:
error_message += f"Missing features: {missing_features}\n"
if extra_features:
error_message += f"Extra features: {extra_features}\n"
return error_message
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
) -> str:
"""Validate the dtype and shape of a single feature's value.
Args:
name (str): The name of the feature.
feature (dict): The feature specification from the LeRobot features dictionary.
value: The value of the feature to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
Raises:
NotImplementedError: If the feature dtype is not supported for validation.
"""
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
) -> str:
"""Validate a feature that is expected to be a numpy array.
Args:
name (str): The name of the feature.
expected_dtype (str): The expected numpy dtype as a string.
expected_shape (list[int]): The expected shape.
value (np.ndarray): The numpy array to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
actual_shape = value.shape
if actual_dtype != np.dtype(expected_dtype):
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
if actual_shape != expected_shape:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
else:
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
) -> str:
"""Validate a feature that is expected to be an image or video frame.
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W).
value: The image data to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str) -> str:
"""Validate a feature that is expected to be a string.
Args:
name (str): The name of the feature.
value (str): The value to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
"""Validate the episode buffer before it's written to disk.
Ensures the buffer has the required keys, contains at least one frame, and
has features consistent with the dataset's specification.
Args:
episode_buffer (dict): The buffer containing data for a single episode.
total_episodes (int): The current total number of episodes in the dataset.
features (dict): The LeRobot features dictionary for the dataset.
Raises:
ValueError: If the buffer is invalid.
NotImplementedError: If the episode index is manually set and doesn't match.
"""
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
if "task" not in episode_buffer:
raise ValueError("task key not found in episode_buffer")
if episode_buffer["episode_index"] != total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `features`."
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
+6 -4
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import multiprocessing
import queue
import threading
@@ -22,6 +23,8 @@ import numpy as np
import PIL.Image
import torch
logger = logging.getLogger(__name__)
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
@@ -31,7 +34,7 @@ def safe_stop_image_writer(func):
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")
logger.warning("Waiting for image writer to terminate...")
image_writer.stop()
raise e
@@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
PIL.Image.Image object.
Side Effects:
Prints an error message to the console if the image writing process
fails for any reason.
Logs an error message if the image writing process fails for any reason.
"""
try:
if isinstance(image, np.ndarray):
@@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
logger.error("Error writing image %s: %s", fpath, e)
def worker_thread_loop(queue: queue.Queue):
+342
View File
@@ -0,0 +1,342 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Any
import datasets
import numpy as np
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from datasets.table import embed_table_storage
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.datasets.utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
STATS_PATH,
flatten_dict,
serialize_dict,
unflatten_dict,
)
from lerobot.utils.utils import SuppressProgressBars
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
total_uncompressed_size = 0
for row_group in range(metadata.num_row_groups):
rg_metadata = metadata.row_group(row_group)
for column in range(rg_metadata.num_columns):
col_metadata = rg_metadata.column(column)
total_uncompressed_size += col_metadata.total_uncompressed_size
return total_uncompressed_size / (1024**2)
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes // (1024**2)
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
with SuppressProgressBars():
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
metadata = pq.read_metadata(parquet_path)
return metadata.num_rows
def get_file_size_in_mb(file_path: Path) -> float:
"""Get file size on disk in megabytes.
Args:
file_path (Path): Path to the file.
"""
file_size_bytes = file_path.stat().st_size
return file_size_bytes / (1024**2)
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
"""Embed image bytes into the dataset table before saving to Parquet.
This function prepares a Hugging Face dataset for serialization by converting
image objects into an embedded format that can be stored in Arrow/Parquet.
Args:
dataset (datasets.Dataset): The input dataset, possibly containing image features.
Returns:
datasets.Dataset: The dataset with images embedded in the table storage.
"""
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
return dataset
def load_json(fpath: Path) -> Any:
"""Load data from a JSON file.
Args:
fpath (Path): Path to the JSON file.
Returns:
Any: The data loaded from the JSON file.
"""
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
"""Write data to a JSON file.
Creates parent directories if they don't exist.
Args:
data (dict): The dictionary to write.
fpath (Path): The path to the output JSON file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def write_info(info: dict, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
"""Load dataset info metadata from its standard file path.
Also converts shape lists to tuples for consistency.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: The dataset information dictionary.
"""
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def write_stats(stats: dict, local_dir: Path) -> None:
"""Serialize and write dataset statistics to their standard file path.
Args:
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
local_dir (Path): The root directory of the dataset.
"""
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
Args:
stats (dict): The statistics dictionary.
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
"""Load dataset statistics and cast numerical values to numpy arrays.
Returns None if the stats file doesn't exist.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
A dictionary of statistics or None if the file is not found.
"""
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
path = local_dir / DEFAULT_TASKS_PATH
path.parent.mkdir(parents=True, exist_ok=True)
tasks.to_parquet(path)
def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
tasks.index.name = "task"
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
Used primarily during dataset conversion (v2.1 v3.0) and in test fixtures.
Args:
episodes: HuggingFace Dataset containing episode metadata
local_dir: Root directory where the dataset will be stored
"""
episode_size_mb = get_hf_dataset_size_in_mb(episodes)
if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
raise NotImplementedError(
f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
"This function only supports single-file episode metadata. "
)
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
fpath.parent.mkdir(parents=True, exist_ok=True)
episodes.to_parquet(fpath)
def load_episodes(local_dir: Path) -> datasets.Dataset:
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
# Select episode features/columns containing references to episode data and videos
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
# This is to speedup access to these data, instead of having to load episode stats.
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
return episodes
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
"""Load an image from a file into a numpy array.
Args:
fpath (str | Path): Path to the image file.
dtype (np.dtype): The desired data type of the output array. If floating,
pixels are scaled to [0, 1].
channel_first (bool): If True, converts the image to (C, H, W) format.
Otherwise, it remains in (H, W, C) format.
Returns:
np.ndarray: The image as a numpy array.
"""
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if np.issubdtype(dtype, np.floating):
img_array /= 255.0
return img_array
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path)
def item_to_torch(item: dict) -> dict:
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
This function is used to convert an item from a streaming dataset to PyTorch tensors.
Args:
item (dict): Dictionary of items from a dataset.
Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
+32 -687
View File
@@ -23,526 +23,52 @@ from pathlib import Path
import datasets
import numpy as np
import packaging.version
import pandas as pd
import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
INFO_PATH,
_validate_feature_names,
from lerobot.datasets.compute_stats import compute_episode_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
check_version_compatibility,
create_empty_dataset_info,
create_lerobot_dataset_card,
embed_images,
flatten_dict,
get_delta_indices,
get_file_size_in_mb,
get_hf_features_from_features,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
)
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.io_utils import (
embed_images,
get_file_size_in_mb,
hf_transform_to_torch,
load_episodes,
load_nested_dataset,
write_info,
write_json,
write_stats,
write_tasks,
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
create_lerobot_dataset_card,
get_safe_version,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
encode_video_frames,
get_safe_default_codec,
get_video_duration_in_s,
get_video_info,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
metadata_buffer_size: int = 10,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.writer = None
self.latest_episode = None
self.metadata_buffer: list[dict] = []
self.metadata_buffer_size = metadata_buffer_size
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
def _flush_metadata_buffer(self) -> None:
"""Write all buffered episode metadata to parquet file."""
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
return
combined_dict = {}
for episode_dict in self.metadata_buffer:
for key, value in episode_dict.items():
if key not in combined_dict:
combined_dict[key] = []
# Extract value and serialize numpy arrays
# because PyArrow's from_pydict function doesn't support numpy arrays
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
first_ep = self.metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
file_idx = first_ep["meta/episodes/file_index"][0]
table = pa.Table.from_pydict(combined_dict)
if not self.writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path.parent.mkdir(parents=True, exist_ok=True)
self.writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self.writer.write_table(table)
self.latest_episode = self.metadata_buffer[-1]
self.metadata_buffer.clear()
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
self._flush_metadata_buffer()
writer = getattr(self, "writer", None)
if writer is not None:
writer.close()
self.writer = None
def __del__(self):
"""
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
"""
self._close_writer()
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
@property
def url_root(self) -> str:
return f"hf://datasets/{self.repo_id}"
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep["data/chunk_index"]
file_idx = ep["data/file_index"]
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
file_idx = ep[f"videos/{vid_key}/file_index"]
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["video_path"]
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
return self.info["robot_type"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def features(self) -> dict[str, dict]:
"""All features contained in the dataset."""
return self.info["features"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
return {key: ft["names"] for key, ft in self.features.items()}
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def chunks_size(self) -> int:
"""Max number of files per chunk."""
return self.info["chunks_size"]
@property
def data_files_size_in_mb(self) -> int:
"""Max size of data file in mega bytes."""
return self.info["data_files_size_in_mb"]
@property
def video_files_size_in_mb(self) -> int:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise return None.
"""
if task in self.tasks.index:
return int(self.tasks.loc[task].task_index)
else:
return None
def save_episode_tasks(self, tasks: list[str]):
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
self.tasks.loc[task] = task_idx
if len(new_tasks) > 0:
# Update on disk
write_tasks(self.tasks, self.root)
def _save_episode_metadata(self, episode_dict: dict) -> None:
"""Buffer episode metadata and write to parquet in batches for efficiency.
This function accumulates episode metadata in a buffer and flushes it when the buffer
reaches the configured size. This reduces I/O overhead by writing multiple episodes
at once instead of one row at a time.
Notes: We both need to update parquet files and HF dataset:
- `pandas` loads parquet file in RAM
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
or loads directly from pyarrow cache.
"""
# Convert to list format for each value
episode_dict = {key: [value] for key, value in episode_dict.items()}
num_frames = episode_dict["length"][0]
if self.latest_episode is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self.episodes is not None and len(self.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
file_idx = self.episodes[-1]["meta/episodes/file_index"]
latest_num_frames = self.episodes[-1]["dataset_to_index"]
episode_dict["dataset_from_index"] = [latest_num_frames]
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
else:
episode_dict["dataset_from_index"] = [0]
episode_dict["dataset_to_index"] = [num_frames]
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
else:
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
file_idx = self.latest_episode["meta/episodes/file_index"][0]
latest_path = (
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
if self.writer is None
else self.writer.where
)
if Path(latest_path).exists():
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
latest_num_frames = self.latest_episode["episode_index"][0]
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
# Size limit is reached, flush buffer and prepare new parquet file
self._flush_metadata_buffer()
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
self._close_writer()
# Update the existing pandas dataframe with new row
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
# Add to buffer
self.metadata_buffer.append(episode_dict)
self.latest_episode = episode_dict
if len(self.metadata_buffer) >= self.metadata_buffer_size:
self._flush_metadata_buffer()
def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
"length": episode_length,
}
episode_dict.update(episode_metadata)
episode_dict.update(flatten_dict({"stats": episode_stats}))
self._save_episode_metadata(episode_dict)
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["total_tasks"] = len(self.tasks)
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.
This allows users to customize storage organization without modifying the constructor.
These settings control how episodes are chunked and how large files can grow before
creating new ones.
Args:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
self.info["chunks_size"] = chunks_size
if data_files_size_in_mb is not None:
if data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
self.info["data_files_size_in_mb"] = data_files_size_in_mb
if video_files_size_in_mb is not None:
if video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
def get_chunk_settings(self) -> dict[str, int]:
"""Get current chunk and file size settings.
Returns:
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
}
def __repr__(self):
feature_keys = list(self.features)
return (
f"{self.__class__.__name__}({{\n"
f" Repository ID: '{self.repo_id}',\n"
f" Total episodes: '{self.total_episodes}',\n"
f" Total frames: '{self.total_frames}',\n"
f" Features: '{feature_keys}',\n"
"})',\n"
)
@classmethod
def create(
cls,
repo_id: str,
fps: int,
features: dict,
robot_type: str | None = None,
root: str | Path | None = None,
use_videos: bool = True,
metadata_buffer_size: int = 10,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
CODEBASE_VERSION,
fps,
features,
use_videos,
robot_type,
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
obj.latest_episode = None
obj.metadata_buffer = []
obj.metadata_buffer_size = metadata_buffer_size
return obj
logger = logging.getLogger(__name__)
def _encode_video_worker(
@@ -596,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
the dataset from that address and load it, pending your dataset is compliant with
codebase_version v3.0. If your dataset has been created before this new format, you will be
prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at
lerobot/datasets/v30/convert_dataset_v21_to_v30.py.
lerobot/scripts/convert_dataset_v21_to_v30.py.
2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty
@@ -664,11 +190,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
for the README).
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
will be stored under root/repo_id.
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
'~/.cache/huggingface/lerobot'.
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory where the dataset will be downloaded and
stored. If set, all dataset files will be stored directly under this path. If not set, the
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
HF_LEROBOT_HOME environment variable).
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
@@ -747,7 +273,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -839,7 +365,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
@@ -1326,7 +852,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logging.error(f"Video encoding failed for {video_key}: {exc}")
logger.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self.meta.video_keys:
@@ -1365,7 +891,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if end_episode is None:
end_episode = self.num_episodes
logging.info(
logger.info(
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
)
@@ -1375,7 +901,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logging.info(f"Encoding videos for episode {ep_idx}")
logger.info(f"Encoding videos for episode {ep_idx}")
if (
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
@@ -1605,7 +1131,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
if isinstance(self.image_writer, AsyncImageWriter):
logging.warning(
logger.warning(
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
)
@@ -1683,7 +1209,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
obj.episodes = None
@@ -1717,183 +1242,3 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._streaming_encoder = None
return obj
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Number of Samples: {self.num_frames},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
+210
View File
@@ -0,0 +1,210 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Callable
from pathlib import Path
import datasets
import torch
import torch.utils
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import VideoFrame
from lerobot.utils.constants import HF_LEROBOT_HOME
logger = logging.getLogger(__name__)
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
if extra_keys:
logger.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Number of Samples: {self.num_frames},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
-382
View File
@@ -1,382 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An online buffer for the online training loop in train.py
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
supports in-place slicing and mutation which is very handy for a dynamic buffer.
"""
import os
from pathlib import Path
from typing import Any
import numpy as np
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def _make_memmap_safe(**kwargs) -> np.memmap:
"""Make a numpy memmap with checks on available disk space first.
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
For information on dtypes:
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
"""
if kwargs["mode"].startswith("w"):
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
stats = os.statvfs(Path(kwargs["filename"]).parent)
available_space = stats.f_bavail * stats.f_frsize # bytes
if required_space >= available_space * 0.8:
raise RuntimeError(
f"You're about to take up {required_space} of {available_space} bytes available."
)
return np.memmap(**kwargs)
class OnlineBuffer(torch.utils.data.Dataset):
"""FIFO data buffer for the online training loop in train.py.
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
loop in the same way that a LeRobotDataset would be used.
The underlying data structure will have data inserted in a circular fashion. Always insert after the
last index, and when you reach the end, wrap around to the start.
The data is stored in a numpy memmap.
"""
NEXT_INDEX_KEY = "_next_index"
OCCUPANCY_MASK_KEY = "_occupancy_mask"
INDEX_KEY = "index"
FRAME_INDEX_KEY = "frame_index"
EPISODE_INDEX_KEY = "episode_index"
TIMESTAMP_KEY = "timestamp"
IS_PAD_POSTFIX = "_is_pad"
def __init__(
self,
write_dir: str | Path,
data_spec: dict[str, Any] | None,
buffer_capacity: int | None,
fps: float | None = None,
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
):
"""
The online buffer can be provided from scratch or you can load an existing online buffer by passing
a `write_dir` associated with an existing buffer.
Args:
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
Note that if the files already exist, they are opened in read-write mode (used for training
resumption.)
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
but note that "index", "frame_index" and "episode_index" are already accounted for by this
class, so you don't need to include them.
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
system's available disk space when choosing this.
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
converted to dict[str, np.ndarray] for optimization purposes.
"""
self.set_delta_timestamps(delta_timestamps)
self._fps = fps
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
# the requested frames. It is only used when `delta_timestamps` is provided.
# minus 1e-4 to account for possible numerical error
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
self._buffer_capacity = buffer_capacity
data_spec = self._make_data_spec(data_spec, buffer_capacity)
Path(write_dir).mkdir(parents=True, exist_ok=True)
self._data = {}
for k, v in data_spec.items():
self._data[k] = _make_memmap_safe(
filename=Path(write_dir) / k,
dtype=v["dtype"] if v is not None else None,
mode="r+" if (Path(write_dir) / k).exists() else "w+",
shape=tuple(v["shape"]) if v is not None else None,
)
@property
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
return self._delta_timestamps
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
"""Set delta_timestamps converting the values to numpy arrays.
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
need to be converted into numpy arrays.
"""
if value is not None:
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
else:
self._delta_timestamps = None
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
)
preset_keys = {
OnlineBuffer.INDEX_KEY,
OnlineBuffer.FRAME_INDEX_KEY,
OnlineBuffer.EPISODE_INDEX_KEY,
OnlineBuffer.TIMESTAMP_KEY,
}
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
raise ValueError(
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
f"The provided data_spec has {intersection}."
)
complete_data_spec = {
# _next_index will be a pointer to the next index that we should start filling from when we add
# more data.
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
"""Add new data to the buffer, which could potentially mean shifting old data out.
The new data should contain all the frames (in order) of any number of episodes. The indices should
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
will be done in place!
"""
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
raise ValueError(f"Missing data keys: {missing_keys}")
new_data_length = len(data[self.data_keys[0]])
if not all(len(data[k]) == new_data_length for k in self.data_keys):
raise ValueError("All data items should have the same length")
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
# Sanity check to make sure that the new data indices start from 0.
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
for k in self.data_keys:
if n_surplus == 0:
slc = slice(next_index, next_index + new_data_length)
self._data[k][slc] = data[k]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
else:
self._data[k][next_index:] = data[k][:-n_surplus]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
self._data[k][:n_surplus] = data[k][-n_surplus:]
if n_surplus == 0:
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
else:
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
@property
def data_keys(self) -> list[str]:
keys = set(self._data)
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
return sorted(keys)
@property
def fps(self) -> float | None:
return self._fps
@property
def num_episodes(self) -> int:
return len(
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
)
@property
def num_frames(self) -> int:
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
def __len__(self):
return self.num_frames
def _item_to_tensors(self, item: dict) -> dict:
item_ = {}
for k, v in item.items():
if isinstance(v, torch.Tensor):
item_[k] = v
elif isinstance(v, np.ndarray):
item_[k] = torch.from_numpy(v)
else:
item_[k] = torch.tensor(v)
return item_
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self) or idx < -len(self):
raise IndexError
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
if self.delta_timestamps is None:
return self._item_to_tensors(item)
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
episode_data_indices = np.where(
np.bitwise_and(
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
# Get timestamps used as query to retrieve data of previous/future frames.
query_ts = current_ts + self.delta_timestamps[data_key]
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
# the episode.
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
argmin_ = np.argmin(dist, axis=1)
min_ = dist[np.arange(dist.shape[0]), argmin_]
is_pad = min_ > self.tolerance_s
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
)
# Load frames for this data key.
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
return self._item_to_tensors(item)
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
def compute_sampler_weights(
offline_dataset: LeRobotDataset,
offline_drop_n_last_frames: int = 0,
online_dataset: OnlineBuffer | None = None,
online_sampling_ratio: float | None = None,
online_drop_n_last_frames: int = 0,
) -> torch.Tensor:
"""Compute the sampling weights for the online training dataloader in train.py.
Args:
offline_dataset: The LeRobotDataset used for offline pre-training.
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
online_dataset: The OnlineBuffer used in online training.
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
online dataset is provided, this value must also be provided.
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
dataset.
Returns:
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
Notes to maintainers:
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
is the ability to turn shuffling off.
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
weights = []
if len(offline_dataset) > 0:
offline_data_mask_indices = []
for start_index, end_index in zip(
offline_dataset.meta.episodes["dataset_from_index"],
offline_dataset.meta.episodes["dataset_to_index"],
strict=True,
):
offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
weights.append(
torch.full(
size=(len(offline_dataset),),
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
)
* offline_data_mask
)
if online_dataset is not None and len(online_dataset) > 0:
online_data_mask_indices = []
episode_indices = online_dataset.get_data_by_key("episode_index")
for episode_idx in torch.unique(episode_indices):
where_episode = torch.where(episode_indices == episode_idx)
start_index = where_episode[0][0]
end_index = where_episode[0][-1] + 1
online_data_mask_indices.extend(
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
)
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
online_data_mask[torch.tensor(online_data_mask_indices)] = True
weights.append(
torch.full(
size=(len(online_dataset),),
fill_value=online_sampling_ratio / online_data_mask.sum(),
)
* online_data_mask
)
weights = torch.cat(weights)
if weights.sum() == 0:
weights += 1 / len(weights)
else:
weights /= weights.sum()
return weights
+9 -6
View File
@@ -17,8 +17,9 @@ from collections.abc import Sequence
from typing import Any
from lerobot.configs.types import PipelineFeatureType
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
@@ -43,11 +44,11 @@ def create_initial_features(
return features
# Helper to filter state/action keys based on regex patterns.
def should_keep(key: str, patterns: tuple[str]) -> bool:
# Helper to filter state/action keys based on compiled regex patterns.
def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
return any(pat.search(key) for pat in patterns)
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
@@ -88,6 +89,8 @@ def aggregate_pipeline_dataset_features(
Returns:
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
"""
compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None
all_features = pipeline.transform_features(initial_features)
# Intermediate storage for categorized and filtered features.
@@ -119,7 +122,7 @@ def aggregate_pipeline_dataset_features(
# 2. Apply filtering rules.
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, patterns):
if not is_image and not should_keep(key, compiled_patterns):
continue
# 3. Add the feature to the appropriate group with a clean name.
@@ -1,73 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datasets
import torch
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index
+25
View File
@@ -13,10 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Iterator
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
def __init__(
@@ -39,13 +42,35 @@ class EpisodeAwareSampler:
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self.shuffle = shuffle
+163 -7
View File
@@ -13,7 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Generator, Iterator
from collections import deque
from collections.abc import Callable, Generator, Iterable, Iterator
from pathlib import Path
import datasets
@@ -21,16 +22,13 @@ import numpy as np
import torch
from datasets import load_dataset
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_delta_indices
from lerobot.datasets.io_utils import item_to_torch
from lerobot.datasets.utils import (
Backtrackable,
LookAheadError,
LookBackError,
check_version_compatibility,
find_float_index,
get_delta_indices,
is_float_in_list,
item_to_torch,
safe_shard,
)
from lerobot.datasets.video_utils import (
@@ -40,6 +38,164 @@ from lerobot.datasets.video_utils import (
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
class LookBackError(Exception):
"""
Exception raised when trying to look back in the history of a Backtrackable object.
"""
pass
class LookAheadError(Exception):
"""
Exception raised when trying to look ahead in the future of a Backtrackable object.
"""
pass
class Backtrackable[T]:
"""
Wrap any iterator/iterable so you can step back up to `history` items
and look ahead up to `lookahead` items.
This is useful for streaming datasets where you need to access previous and future items
but can't load the entire dataset into memory.
Example:
-------
```python
ds = load_dataset("c4", "en", streaming=True, split="train")
rev = Backtrackable(ds, history=3, lookahead=2)
x0 = next(rev) # forward
x1 = next(rev)
x2 = next(rev)
# Look ahead
x3_peek = rev.peek_ahead(1) # next item without moving cursor
x4_peek = rev.peek_ahead(2) # two items ahead
# Look back
x1_again = rev.peek_back(1) # previous item without moving cursor
x0_again = rev.peek_back(2) # two items back
# Move backward
x1_back = rev.prev() # back one step
next(rev) # returns x2, continues forward from where we were
```
"""
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
if history < 1:
raise ValueError("history must be >= 1")
if lookahead <= 0:
raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable)
self._back_buf: deque[T] = deque(maxlen=history)
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0
self._history = history
self._lookahead = lookahead
def __iter__(self) -> "Backtrackable[T]":
return self
def __next__(self) -> T:
# If we've stepped back, consume from back buffer first
if self._cursor < 0: # -1 means "last item", etc.
self._cursor += 1
return self._back_buf[self._cursor]
# If we have items in the ahead buffer, use them first
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
# Add current item to back buffer and reset cursor
self._back_buf.append(item)
self._cursor = 0
return item
def prev(self) -> T:
"""
Step one item back in history and return it.
Raises IndexError if already at the oldest buffered item.
"""
if len(self._back_buf) + self._cursor <= 1:
raise LookBackError("At start of history")
self._cursor -= 1
return self._back_buf[self._cursor]
def peek_back(self, n: int = 1) -> T:
"""
Look `n` items back (n=1 == previous item) without moving the cursor.
"""
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
raise LookBackError("peek_back distance out of range")
return self._back_buf[self._cursor - (n + 1)]
def peek_ahead(self, n: int = 1) -> T:
"""
Look `n` items ahead (n=1 == next item) without moving the cursor.
Fills the ahead buffer if necessary.
"""
if n < 1:
raise LookAheadError("peek_ahead distance must be 1 or more")
elif n > self._lookahead:
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
# Fill ahead buffer if we don't have enough items
while len(self._ahead_buf) < n:
try:
item = next(self._source)
self._ahead_buf.append(item)
except StopIteration as err:
raise LookAheadError("peek_ahead: not enough items in source") from err
return self._ahead_buf[n - 1]
def history(self) -> list[T]:
"""
Return a copy of the buffered history (most recent last).
The list length `history` argument passed at construction.
"""
if self._cursor == 0:
return list(self._back_buf)
# When cursor<0, slice so the order remains chronological
return list(self._back_buf)[: self._cursor or None]
def can_peek_back(self, steps: int = 1) -> bool:
"""
Check if we can go back `steps` items without raising an IndexError.
"""
return steps <= len(self._back_buf) + self._cursor
def can_peek_ahead(self, steps: int = 1) -> bool:
"""
Check if we can peek ahead `steps` items.
This may involve trying to fill the ahead buffer.
"""
if self._lookahead > 0 and steps > self._lookahead:
return False
# Try to fill ahead buffer to check if we can peek that far
try:
while len(self._ahead_buf) < steps:
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
return False
item = next(self._source)
self._ahead_buf.append(item)
return True
except StopIteration:
return False
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""LeRobotDataset with streaming capabilities.
File diff suppressed because it is too large Load Diff
+51 -43
View File
@@ -37,6 +37,8 @@ import torchvision
from datasets.features.features import register_feature
from PIL import Image
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
@@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
pass # nosec B110
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
return available
@@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str:
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logging.info(f"Using video codec: {vcodec}")
logger.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logging.info(f"Auto-selected video codec: {encoder}")
logger.info(f"Auto-selected video codec: {encoder}")
return encoder
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
@@ -118,7 +120,7 @@ def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
logger.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
@@ -208,7 +210,7 @@ def decode_video_frames_torchvision(
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
@@ -227,28 +229,33 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
logger.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames)
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
return closest_frames
@@ -343,7 +350,7 @@ def decode_video_frames_torchcodec(
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
logging.info(f"Frame loaded at timestamp={pts:.4f}")
logger.info(f"Frame loaded at timestamp={pts:.4f}")
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
@@ -353,22 +360,23 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
logger.info(f"{closest_ts=}")
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
@@ -402,14 +410,14 @@ def encode_video_frames(
imgs_dir = Path(imgs_dir)
if video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
return
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logging.warning(
logger.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
@@ -502,7 +510,7 @@ def concatenate_video_files(
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
@@ -687,7 +695,7 @@ class _CameraEncoderThread(threading.Thread):
self.result_queue.put(("ok", None))
except Exception as e:
logging.error(f"Encoder thread error: {e}")
logger.error(f"Encoder thread error: {e}")
if container is not None:
with contextlib.suppress(Exception):
container.close()
@@ -813,7 +821,7 @@ class StreamingVideoEncoder:
count = self._dropped_frames[video_key]
# Log periodically to avoid spam (1st, then every 10th)
if count == 1 or count % 10 == 0:
logging.warning(
logger.warning(
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
)
@@ -835,7 +843,7 @@ class StreamingVideoEncoder:
# Report dropped frames
for video_key, count in self._dropped_frames.items():
if count > 0:
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
# Send sentinel to all queues
for video_key in self._frame_queues:
@@ -845,7 +853,7 @@ class StreamingVideoEncoder:
for video_key in self._threads:
self._threads[video_key].join(timeout=120)
if self._threads[video_key].is_alive():
logging.error(f"Encoder thread for {video_key} did not finish in time")
logger.error(f"Encoder thread for {video_key} did not finish in time")
self._stop_events[video_key].set()
self._threads[video_key].join(timeout=5)
results[video_key] = (self._video_paths[video_key], None)
@@ -857,7 +865,7 @@ class StreamingVideoEncoder:
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
results[video_key] = (self._video_paths[video_key], data)
except queue.Empty:
logging.error(f"No result from encoder thread for {video_key}")
logger.error(f"No result from encoder thread for {video_key}")
results[video_key] = (self._video_paths[video_key], None)
self._cleanup()
@@ -1065,13 +1073,13 @@ class VideoEncodingManager:
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
if exc_type is not None:
logging.info("Exception occurred. Encoding remaining episodes before exit...")
logger.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logging.info("Recording stopped. Encoding remaining episodes...")
logger.info("Recording stopped. Encoding remaining episodes...")
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logging.info(
logger.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
@@ -1088,7 +1096,7 @@ class VideoEncodingManager:
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logging.debug(
logger.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
@@ -1099,8 +1107,8 @@ class VideoEncodingManager:
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
shutil.rmtree(img_dir)
logging.debug("Cleaned up empty images directory")
logger.debug("Cleaned up empty images directory")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
+1 -1
View File
@@ -29,7 +29,7 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
from lerobot.types import RobotObservation
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
+1 -1
View File
@@ -25,7 +25,7 @@ import metaworld.policies as policies
import numpy as np
from gymnasium import spaces
from lerobot.processor import RobotObservation
from lerobot.types import RobotObservation
# ---- Load configuration data from the external JSON file ----
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
+1 -1
View File
@@ -29,7 +29,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.processor import RobotObservation
from lerobot.types import RobotObservation
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
+4 -4
View File
@@ -29,7 +29,7 @@ from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pprint import pformat
from typing import Protocol, TypeAlias
from typing import Protocol
import serial
from deepdiff import DeepDiff
@@ -38,8 +38,8 @@ from tqdm import tqdm
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
type NameOrID = str | int
type Value = int | float
logger = logging.getLogger(__name__)
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
# Backward compatibility alias
MotorsBus: TypeAlias = SerialMotorsBus
MotorsBus = SerialMotorsBus
+2 -1
View File
@@ -23,7 +23,8 @@ import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.datasets.io_utils import write_json
from lerobot.datasets.utils import flatten_dict, unflatten_dict
from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
+1 -1
View File
@@ -23,7 +23,7 @@ import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.datasets.utils import write_json
from lerobot.datasets.io_utils import write_json
from lerobot.utils.constants import SCHEDULER_STATE
from lerobot.utils.io_utils import deserialize_json_into_object
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
backbone. If None, no resizing is done and the original image resolution is used.
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
this is computed automatically. Can also be set directly for legacy configs that use
crop-only (without resize). If None and no derivation applies, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center
crop in eval mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
resize_shape: tuple[int, int] | None = None
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
@@ -139,6 +147,10 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -171,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}."
)
if self.resize_shape is not None and (
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
):
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
if not (0 < self.crop_ratio <= 1.0):
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
if self.resize_shape is not None:
if self.crop_ratio < 1.0:
self.crop_shape = (
int(self.resize_shape[0] * self.crop_ratio),
int(self.resize_shape[1] * self.crop_ratio),
)
else:
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
self.crop_shape = None
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
@@ -198,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None:
if self.resize_shape is None and self.crop_shape is not None:
for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for "
f"`{key}`."
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for `{key}`."
)
# Check that all input images have the same shape.
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
for key in self.config.image_features:
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
@@ -446,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if config.crop_shape is not None:
if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -477,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
# Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
if config.crop_shape is not None:
dummy_shape_h_w = config.crop_shape
elif config.resize_shape is not None:
dummy_shape_h_w = config.resize_shape
else:
dummy_shape_h_w = images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -499,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
Returns:
(B, D) image feature.
"""
# Preprocess: maybe crop (if it was set up in the __init__).
# Preprocess: resize if configured, then crop if configured.
if self.resize is not None:
x = self.resize(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
+5 -5
View File
@@ -18,15 +18,14 @@ from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict
from typing import Any, TypedDict, Unpack
import torch
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
@@ -44,13 +43,14 @@ from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
transition_to_policy_action,
)
from lerobot.types import PolicyAction
from lerobot.utils.constants import (
ACTION,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
@@ -4,10 +4,9 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import annotations
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _resize_for_patching(
self,
image: "torch.Tensor",
image: torch.Tensor,
target_resolution: tuple,
interpolation: "F.InterpolationMode",
interpolation: F.InterpolationMode,
input_data_format: ChannelDimension,
) -> "torch.Tensor":
) -> torch.Tensor:
"""
Resizes an image to a target resolution while maintaining aspect ratio.
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
return best_ratio
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
) -> torch.Tensor:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _get_image_patches(
self,
image: "torch.Tensor",
image: torch.Tensor,
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: "F.InterpolationMode",
interpolation: F.InterpolationMode,
pad_during_tiling: bool,
) -> list["torch.Tensor"]:
) -> list[torch.Tensor]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _pad_for_batching(
self,
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
pixel_values: list[torch.Tensor],
) -> list[torch.Tensor]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
def _preprocess(
self,
images: list["torch.Tensor"],
images: list[torch.Tensor],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: Optional["F.InterpolationMode"],
interpolation: F.InterpolationMode | None,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
@@ -49,7 +49,7 @@ from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
ACTION,
HF_LEROBOT_HOME,
+1 -2
View File
@@ -20,12 +20,11 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
+5 -5
View File
@@ -20,12 +20,11 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -99,10 +98,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
return dist.sample((bsize,)).to(device)
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
+1 -1
View File
@@ -36,7 +36,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,

Some files were not shown because too many files have changed in this diff Show More