Compare commits

..

14 Commits

Author SHA1 Message Date
Cheng Yin 9db9c35cb4 fix(config): add lora_alpha to PeftConfig (#3573)
* fix(config): add lora_alpha to PeftConfig

PeftConfig was missing the lora_alpha field, causing the PEFT library
to default to alpha=8 regardless of the LoRA rank, which dampens the
adaptation signal for high-rank adapters (e.g., r=128).

This adds lora_alpha: int | None = None to PeftConfig, allowing users
to specify --peft.lora_alpha <value> on the CLI.

Closes #3551

* fix(docs): add lora_alpha to peft training example + clarify scaling formula

- Add --peft.lora_alpha=64 to docs/source/peft_training.mdx example to
  prevent new users from hitting the alpha=8 default dampening bug
- Clarify lora_alpha comment in default.py with scaling = lora_alpha / r

* docs: mention both --peft.r and --peft.lora_alpha in LoRA description
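
For reference, the scaling arithmetic behind the bug, as a minimal sketch (the numbers are illustrative):

```python
# PEFT scales the LoRA update by lora_alpha / r (per the commit message).
r = 128                 # high-rank adapter
print(8 / r)            # 0.0625 -> the silent alpha=8 default dampens the update
print(64 / r)           # 0.5    -> an explicit --peft.lora_alpha=64 restores it
```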

---------

Co-authored-by: Cheng Yin <yin@users.noreply.github.com>
2026-05-13 11:09:19 +02:00
Jash Shah fe96b28c74 Fix policy.path not working in YAML config files (#3145)
* fix(config): support policy.path in YAML config files

policy.path was only handled via CLI args (filtered from sys.argv before
draccus, then retrieved in validate()). When specified in YAML, draccus
would crash because 'path' is not a valid field on PreTrainedConfig.

Extract path fields from the YAML/JSON config before draccus processes
it, store them in a module-level dict, and fall back to it in
get_path_arg() when the CLI doesn't have the path.

Fixes #2957
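
A minimal sketch of the mechanism described above, assuming simplified helper names (the real parser code differs):

```python
import yaml

# Module-level fallback store for path fields stripped from the YAML config.
_yaml_path_args: dict[str, str] = {}

def strip_path_fields(config_file: str) -> dict:
    """Remove `path` keys (e.g. policy.path) before draccus parses the config."""
    with open(config_file) as f:
        cfg = yaml.safe_load(f)
    policy = cfg.get("policy") or {}
    if "path" in policy:
        _yaml_path_args["policy.path"] = policy.pop("path")
    return cfg

def get_path_arg(field: str, cli_args: list[str]) -> str | None:
    prefix = f"--{field}="
    for arg in cli_args:
        if arg.startswith(prefix):
            return arg.removeprefix(prefix)  # CLI always wins
    return _yaml_path_args.get(field)        # fall back to the YAML value
```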

* fix(parser): preserve YAML policy overrides when loading from pretrained

When policy.path is set in YAML, validate() was calling from_pretrained
with only CLI overrides, discarding any YAML policy fields (e.g. lr,
batch_size) that draccus had already parsed. Fix by capturing the
remaining YAML fields as CLI-style args in _config_yaml_overrides and
merging them into the overrides passed to from_pretrained in train.py,
eval.py, and lerobot_record.py (CLI args still take precedence).

Also fix the NamedTemporaryFile SIM115 ruff warning and add types-PyYAML
to the mypy pre-commit hook.

* fix(parser): serialize bool/None values correctly in YAML policy overrides

Bool values from YAML configs (e.g. push_to_hub: true) were passed as
Python "True"/"False" strings instead of lowercase "true"/"false" that
draccus expects. Also skip None values to avoid passing "None" strings.
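
The conversion rule, as a hedged sketch (the helper name is an assumption):

```python
def yaml_value_to_cli(value) -> str | None:
    """Serialize a YAML value the way draccus expects it on the CLI."""
    if value is None:
        return None                # skip entirely: "None" strings would break parsing
    if isinstance(value, bool):
        return str(value).lower()  # True -> "true", False -> "false"
    return str(value)

assert yaml_value_to_cli(True) == "true"
assert yaml_value_to_cli(None) is None
```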

* revert: remove types-PyYAML from .pre-commit-config.yaml

* chore: fix quality check caused by untyped YAML import

Co-authored-by: masato-ka <jp6uzv@gmail.com>
Signed-off-by: Khalil Meftah <khalil.meftah@huggingface.co>

---------

Signed-off-by: Khalil Meftah <khalil.meftah@huggingface.co>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
Co-authored-by: masato-ka <jp6uzv@gmail.com>
2026-05-13 09:45:27 +02:00
Steven Palma 2438df1307 chore(dependencies): update uv.lock (#3561) 2026-05-12 21:20:26 +02:00
Caroline Pascal f218d5ab30 feat(episodes): adding support for metadata based episodes filtering (#3530)
* feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset

* test(tests): adding tests

* chore(format): formatting code

* feat(performance): improving implementation for better performances on big datasets

* chores(warning): improving warnings and errors for episodes filtering

* test(invalid key): adding test for invalid filtering key

* chore(format): formatting code
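
A hypothetical usage sketch of the feature (the constructor argument name below is an assumption, not the actual API):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Hypothetical: keep only episodes whose metadata matches the filter at init time.
ds = LeRobotDataset("lerobot/pusht", filters={"task": "push_to_target"})
print(ds.meta.total_episodes)  # only the matching episodes are loaded
```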
2026-05-12 20:44:11 +02:00
Steven Palma 04125492e4 fix(datasets): expand torchcodec platform coverage + rewrite pyav fallback for torchvision >0.26 (#3588)
* fix(deps): better versioning control for torchcodec

* refactor(video_utils): replace torchvision with pyav

* adding Torchcodec version to lerobot-info

* chore(benchmarks): delete video benchmark

---------

Co-authored-by: Maximellerbach <maxime.ellerbach@huggingface.co>
2026-05-12 16:59:11 +02:00
Khalil Meftah e963e5a0c4 RL stack refactoring (#3075)
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring

* chore: clarify torch.compile disabled note in SACAlgorithm

* fix(teleop): keyboard EE teleop not registering special keys and losing intervention state

Fixes #2345

Co-authored-by: jpizarrom <jpizarrom@gmail.com>

* fix: remove leftover normalization calls from reward classifier predict_reward

Fixes #2355

* fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample()

* refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

* perf: remove redundant CPU→GPU→CPU transition move in learner

* Fix: add kwargs in reward classifier __init__()

* fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer

* fix: add try/finally to control_loop to ensure image writer cleanup on exit

* fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error

* fix: skip tests that require grpc if not available

* fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

* fix(tests): skip tests that require grpc if not available

* refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

* fix(config): update vision encoder model name to lerobot/resnet10

* fix(sac): clarify torch.compile status

* refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

* refactor(sac): simplify optimizer return structure

* perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

* refactor(sac): decouple algorithm hyperparameters from policy config

* update loss names in tests

* fix docstring

* remove unused type alias

* fix test for flat dict structure

* refactor(policies): rename policies/sac → policies/gaussian_actor

* refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

* perf(observation_processor): add CUDA support for image processing

* fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline

(cherry picked from commit 9c2af818ff)

* fix(rl): add time limit processor to environment pipeline

(cherry picked from commit cd105f65cb)

* fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100

(cherry picked from commit 494f469a2b)

* fix(rl): update neutral gripper action

(cherry picked from commit 9c9064e5be)

* fix(rl): merge environment and action-processor info in transition processing

(cherry picked from commit 30e1886b64)

* fix(rl): mirror gym_manipulator in actor

(cherry picked from commit d2a046dfc5)

* fix(rl): postprocess action in actor

(cherry picked from commit c2556439e5)

* fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)

* fix(rl): enhance intervention handling in actor and learner

(cherry picked from commit ef8bfffbd7)

* Revert "perf(observation_processor): add CUDA support for image processing"

This reverts commit 38b88c414c.

* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

* refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

* refactor(rl): add type property to RLAlgorithmConfig for better clarity

* refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility

* refactor(tests): remove grpc import checks from test files for cleaner code

* fix(tests): gate RL tests on the `datasets` extra

* refactor: simplify docstrings for clarity and conciseness across multiple files

* fix(rl): update gripper position key and handle action absence during reset

* fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

* refactor: clean up import statements

* chore: address reviewer comments

* chore: improve visual stats reshaping logic and update docstring for clarity

* refactor: enforce mandatory config_class and name attributes in RLAlgorithm

* refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

* refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

* refactor: add require_package calls for grpcio and gym-hil in relevant modules

* refactor(rl): move grpcio guards to runtime entry points

* feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.
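
The resulting checkpoint layout, sketched from the names quoted above (the exact nesting is an assumption):

```
checkpoint/
├── pretrained_model/     # actor, saved as HF-standard policy files
├── algorithm/            # critic ensemble, target nets, log_alpha (via state_dict/load_state_dict)
└── training_step.json    # {"step": ..., "interaction_step": ...}, replaces training_state.pt
```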

* refactor(rl): move actor weight-sync wire format from policy to algorithm

* refactor(rl): update type hints for learner and actor functions

* refactor(rl): hoist grpcio guard to module top in actor/learner

* chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>

* update uv.lock

* chore(doc): update doc

---------

Co-authored-by: jpizarrom <jpizarrom@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-12 15:49:54 +02:00
Steven Palma 26ff40ddd7 chore(deps): cap torch ceiling at <2.12, pin Linux wheels to cu128 (#3570)
* chore(deps): ceiling + cuda

* ci: bump cuda version docker image

* ci: add cpu wheel to release workflow

* chore(deps): update uv.lock

* docs: update installation with cuda note
2026-05-11 19:47:55 +02:00
Maxime Ellerbach 6d269b28c8 docs(omx): adding some examples and scripts (#3566)
* docs(omx): adding some examples and scripts

* cleaning up and reviewing the cli args

* adding __init__.py to example folder, adjusting the examples

* adding reference to pretrained act policy

* moving `.send_action` before `dataset.add_frame` for consistency

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adjusting docstring

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* addressing hardcoded dataset fps

* removed init as it worked without

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-05-11 15:36:32 +02:00
Steven Palma b607c8458e docs: add policy & compute guide (#3534)
* docs(policy): contributing a policy guide

* docs(training): HW compute guide

* chore(docs): add to readme and index

* Apply suggestions from code review

Co-authored-by: Haoming Song <1847575517@qq.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(docs): slight improvements

* refactor(docs): consolidate add policy docs

* chore(style): fix pre-commit

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Haoming Song <1847575517@qq.com>
2026-05-11 15:19:12 +02:00
Jash Shah 9e83510c99 fix(datasets): close file handle on VideoDecoder init failure in cache (#3542)
If VideoDecoder() raises during initialization, the fsspec file handle
was leaked since it was opened via __enter__() but never closed on the
exception path. Now explicitly closes the handle before re-raising.
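
A hedged sketch of the fix (the call sites are assumptions based on the message above):

```python
import fsspec
from torchcodec.decoders import VideoDecoder

f = fsspec.open(video_path, "rb").__enter__()
try:
    decoder = VideoDecoder(f)  # init can raise on corrupt/unsupported files
except Exception:
    f.close()                  # previously leaked on this path
    raise
```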
2026-05-10 17:30:37 +02:00
Anthony Shoumikhin 1f7b03f5f2 chore(deps): allow torch 2.11/2.12 and fix autocast deprecation (#3435)
* chore(deps): allow torch 2.11/2.12 and fix autocast deprecation

- Bump torch to >=2.7,<2.13 (was <2.11), torchvision to <0.28 (was <0.26),
  and torchcodec to <0.13 (was <0.11) to allow installs against the latest
  stable torch 2.11 and the upcoming 2.12 line.
- Replace removed torch.get_autocast_gpu_dtype() with torch.get_autocast_dtype("cuda")
  in Florence2 and Qwen2.5-VL-MoE FlashAttention paths (the former is removed in 2.11+).
- Refresh uv.lock for the new resolution (torch 2.11.0+cu130, torchvision 0.26.0+cu130,
  torchcodec 0.11.1, full CUDA 13 stack).

Verified locally with `uv sync --locked` from a clean .venv and the lerobot
test suite (pytest -n 8 --dist=loadfile --timeout=300). Failure set is
identical to the pre-bump baseline: 18 pre-existing failures
(test_sac_policy*, test_pi0_rtc*, test_pi05_rtc*, test_replay_buffer*),
0 new, 0 fixed.

AI assistance: this change was authored with Claude Code per AI_POLICY.md.

* fix(policies): use device-agnostic autocast dtype lookup

Pass query_states.device.type to torch.get_autocast_dtype() instead of
hardcoding 'cuda', so the cast matches the active autocast context when
running under CPU/MPS/XPU autocast.
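
The replacement in a nutshell, as a sketch (`query_states` is the tensor named in the message):

```python
import torch

# torch.get_autocast_gpu_dtype() is removed in torch 2.11+; the device-agnostic
# lookup matches whatever autocast context is active (CUDA, CPU, MPS, XPU).
def autocast_dtype(query_states: torch.Tensor) -> torch.dtype:
    return torch.get_autocast_dtype(query_states.device.type)
```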

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-10 13:05:35 +02:00
Steven Palma cb8edf17e6 chore(dependencies): update uv.lock (#3475) 2026-05-10 12:24:22 +02:00
Steven Palma 5699f6cbf4 chore(ci): disable auto-stale (#3550) 2026-05-10 11:49:31 +02:00
masato-ka 0e6114ac36 fix(train): restrict legacy RA-BC migration to JSON checkpoints only (#3490)
* fix(train): restrict legacy RA-BC migration to JSON checkpoints only

_migrate_legacy_rabc_fields was called for all config files, causing
json.load to raise DecodeError when a YAML/TOML config was passed to
lerobot-train for a new training run. Guard the block with an
.endswith(".json") check so migration only runs when resuming from
a JSON checkpoint.
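
The guard, sketched (the helper name comes from the message above):

```python
# Only JSON checkpoint configs can contain legacy RA-BC fields; YAML/TOML
# configs for fresh runs must never be passed through json.load.
if str(config_path).endswith(".json"):
    _migrate_legacy_rabc_fields(config_path)
```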
2026-05-08 20:27:01 +02:00
107 changed files with 6915 additions and 5949 deletions
+2 -1
@@ -152,13 +152,14 @@ jobs:
BASE_VERSION="${VERSION%%-*}"
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
uv pip install \
--torch-backend cpu \
--index-url https://test.pypi.org/simple/ \
--extra-index-url https://pypi.org/simple \
--index-strategy unsafe-best-match \
"lerobot[all]==$BASE_VERSION"
else
echo "Installing release version $VERSION from PyPI..."
uv pip install "lerobot[all]==$VERSION"
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
fi
- name: Check lerobot version
run: uv run python -c "import lerobot; print(lerobot.__version__)"
+2 -2
@@ -19,8 +19,8 @@ on:
workflow_dispatch:
# Runs at 02:00
schedule:
- cron: "0 2 * * *"
# schedule:
# - cron: "0 2 * * *"
env:
CLOSE_ISSUE_MESSAGE: >
+2
@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
All policies typically train for **5–10 epochs** (see §7).
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
+1 -1
@@ -109,7 +109,7 @@ lerobot-train \
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
## Inference & Evaluation
-288
@@ -1,288 +0,0 @@
# Video benchmark
## Questions
What is the optimal trade-off between:
- minimizing loading time with random access,
- minimizing memory space on disk,
- maximizing success rate of policies,
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
How to encode videos?
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
- Which frequency to choose for key frames (`-g`)? A key frame every `10` frames?
How to decode videos?
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
- Which scenarios to use for requesting timestamps during the benchmark? (`timestamps_mode`)
## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of simulated images as for real-world images captured in an apartment, in a factory, outdoors, or with lots of moving objects in the scene. Similarly, loading times might not vary linearly with image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
**Data augmentations**
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
### Encoding parameters
| parameter | values |
| ----------- | ------------------------------------------------------------ |
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
| **pix_fmt** | `yuv444p`, `yuv420p` |
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
Note that the `crf` value might be interpreted differently by different video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same across video codecs. The same is true for many other ffmpeg arguments, like `g`, which specifies the frequency of the key frames.
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
### Decoding parameters
**Decoder**
We tested two video decoding backends from torchvision:
- `pyav`
- `video_reader` (requires building torchvision from source)
**Requested timestamps**
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
This is of course affected by the `-g` parameter used during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics, where policies might request a few timestamps at random locations, we replicate these use cases with the following scenarios:
- `1_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
- `2_frames_4_space`: 2 frames spaced 4 frames apart (e.g. `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek, so in practice this scenario is essentially the same as `6_frames`, since all 6 frames between `t` and `t + 5 / fps` will be decoded.
## Metrics
**Data compression ratio (lower is better)**
`video_images_size_ratio` is the ratio of the disk space taken by the encoded video over the disk space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes a quarter of the disk space of the original images.
**Loading time ratio (lower is better)**
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at given timestamps over the time it takes to load the exact same original images. For instance, `video_images_load_time_ratio=200%` means that decoding from video is twice as slow as loading the original images.
**Average Mean Square Error (lower is better)**
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
**Average Peak Signal to Noise Ratio (higher is better)**
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
<!-- **Loss of a pretrained policy (higher is better)** (not available)
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
**Success rate after retraining (higher is better)** (not available)
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
## How the benchmark works
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
**Encoding:** for each `vcodec` and `pix_fmt` pair, we start from default values for `g` and `crf` and vary a single one of them (either `g` or `crf`) across the specified values (we don't test every combination, as that would be computationally too heavy).
This gives a unique set of encoding parameters which is used to encode the episode.
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
Intermediate results are saved in CSV tables for each `vcodec` and `pix_fmt` combination.
These are then concatenated into a single table ready for analysis.
## Caveats
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
Additional encoding parameters exist that are not included in this benchmark. In particular:
- `-preset`, which selects an encoding preset: a collection of options trading encoding speed against compression ratio. When left unspecified, it defaults to `medium` for libx264 and libx265 and to `8` for libsvtav1.
- `-tune`, which optimizes the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`
- `ffmpegio`
- `decord`
- `nvc`
Note as well that since we are mostly interested in decoding performance (encoding is done only once, before uploading a dataset), we did not measure encoding times and have no encoding metrics.
That said, apart from the need to build ffmpeg from source, encoding posed no issues and did not take a significant amount of time during this benchmark.
## Install
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need a second ffmpeg custom-built with all the video codecs for encoding. For the script to use that version, prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
## Adding a video decoder
Right now, we're only benchmarking the two video decoders available with torchvision: `pyav` and `video_reader`.
You can easily add a new decoder to benchmark by adding it to this function in the script:
```diff
def decode_video_frames(
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
+ elif backend == "your_decoder":
+ return your_decoder_function(
+ video_path, timestamps, tolerance_s, backend
+ )
else:
raise NotImplementedError(backend)
```
## Example
For a quick run, you can try these parameters:
```bash
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
--crf 10 40 None \
--timestamps-modes 1_frame 2_frames \
--backends pyav video_reader \
--num-samples 5 \
--num-workers 5 \
--save-frames 0
```
## Results
### Reproduce
We ran the benchmark with the following parameters:
```bash
# h264 and h265 encodings
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
--crf 0 5 10 15 20 25 30 40 50 None \
--timestamps-modes 1_frame 2_frames 6_frames \
--backends pyav video_reader \
--num-samples 50 \
--num-workers 5 \
--save-frames 1
# av1 encoding (only compatible with yuv420p and pyav decoder)
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
--crf 0 5 10 15 20 25 30 40 50 None \
--timestamps-modes 1_frame 2_frames 6_frames \
--backends pyav \
--num-samples 50 \
--num-workers 5 \
--save-frames 1
```
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
### Parameters selected for LeRobotDataset
Considering these results, we chose what we think is the best set of encoding parameters:
- vcodec: `libsvtav1`
- pix-fmt: `yuv420p`
- g: `2`
- crf: `30`
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
### Summary
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`.
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
-488
@@ -1,488 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Assess the performance of video decoding in various configurations.
This script will benchmark different video encoding and decoding parameters.
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
"""
import argparse
import datetime as dt
import itertools
import random
import shutil
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import einops
import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames,
encode_video_frames,
)
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.utils import TimerManager
BASE_ENCODING = OrderedDict(
[
("vcodec", "libx264"),
("pix_fmt", "yuv444p"),
("g", 2),
("crf", None),
# TODO(aliberts): Add fastdecode
# ("fastdecode", 0),
]
)
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
def parse_int_or_none(value) -> int | None:
if value.lower() == "none":
return None
try:
return int(value)
except ValueError as e:
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id)
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
)
def get_directory_size(directory: Path) -> int:
total_size = 0
for item in directory.rglob("*"):
if item.is_file():
total_size += item.stat().st_size
return total_size
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
frame = torch.from_numpy(np.array(frame))
frame = frame.type(torch.float32) / 255
frame = einops.rearrange(frame, "h w c -> c h w")
frames.append(frame)
return torch.stack(frames)
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
return
save_dir.mkdir(parents=True, exist_ok=True)
for i, ts in enumerate(timestamps):
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
return
imgs_dir.mkdir(parents=True, exist_ok=True)
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
if i >= ep_num_images - 1:
break
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
case "1_frame":
frame_indexes = [idx]
case "2_frames":
frame_indexes = [idx - 1, idx]
case "2_frames_4_space":
frame_indexes = [idx - 5, idx]
case "6_frames":
frame_indexes = [idx - i for i in range(6)][::-1]
case _:
raise ValueError(timestamps_mode)
return [idx / fps for idx in frame_indexes]
def benchmark_decoding(
imgs_dir: Path,
video_path: Path,
timestamps_mode: str,
backend: str,
ep_num_images: int,
fps: int,
num_samples: int = 50,
num_workers: int = 4,
save_frames: bool = False,
) -> dict:
def process_sample(sample: int, lock: Lock):
time_benchmark = TimerManager(log=False)
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
num_frames = len(timestamps)
result = {
"psnr_values": [],
"ssim_values": [],
"mse_values": [],
}
with time_benchmark, lock:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
with time_benchmark:
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
)
if save_frames and sample == 0:
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
return result
load_times_video_ms = []
load_times_images_ms = []
mse_values = []
psnr_values = []
ssim_values = []
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
# Use a single shared lock for all worker threads
shared_lock = Lock()
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
psnr_values.extend(result["psnr_values"])
ssim_values.extend(result["ssim_values"])
mse_values.extend(result["mse_values"])
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
return {
"avg_load_time_video_ms": avg_load_time_video_ms,
"avg_load_time_images_ms": avg_load_time_images_ms,
"video_images_load_time_ratio": video_images_load_time_ratio,
"avg_mse": float(np.mean(mse_values)),
"avg_psnr": float(np.mean(psnr_values)),
"avg_ssim": float(np.mean(ssim_values)),
}
def benchmark_encoding_decoding(
dataset: LeRobotDataset,
video_path: Path,
imgs_dir: Path,
encoding_cfg: dict,
decoding_cfg: dict,
num_samples: int,
num_workers: int,
save_frames: bool,
overwrite: bool = False,
seed: int = 1337,
) -> list[dict]:
fps = dataset.fps
if overwrite or not video_path.is_file():
tqdm.write(f"encoding {video_path}")
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=encoding_cfg["vcodec"],
pix_fmt=encoding_cfg["pix_fmt"],
g=encoding_cfg.get("g"),
crf=encoding_cfg.get("crf"),
# fast_decode=encoding_cfg.get("fastdecode"),
overwrite=True,
)
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)
video_images_size_ratio = video_size_bytes / images_size_bytes
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
timestamps_mode,
backend,
ep_num_images,
fps,
num_samples,
num_workers,
save_frames,
)
benchmark_row.update(
**{
"repo_id": dataset.repo_id,
"resolution": f"{width} x {height}",
"num_pixels": num_pixels,
"video_size_bytes": video_size_bytes,
"images_size_bytes": images_size_bytes,
"video_images_size_ratio": video_images_size_ratio,
"timestamps_mode": timestamps_mode,
"backend": backend,
},
**encoding_cfg,
)
benchmark_table.append(benchmark_row)
return benchmark_table
def main(
output_dir: Path,
repo_ids: list[str],
vcodec: list[str],
pix_fmt: list[str],
g: list[int],
crf: list[int],
# fastdecode: list[int],
timestamps_modes: list[str],
backends: list[str],
num_samples: int,
num_workers: int,
save_frames: bool,
):
check_datasets_formats(repo_ids)
encoding_benchmarks = {
"g": g,
"crf": crf,
# "fastdecode": fastdecode,
}
decoding_benchmarks = {
"timestamps_modes": timestamps_modes,
"backends": backends,
}
headers = ["repo_id", "resolution", "num_pixels"]
headers += list(BASE_ENCODING.keys())
headers += [
"timestamps_mode",
"backend",
"video_size_bytes",
"images_size_bytes",
"video_images_size_ratio",
"avg_load_time_video_ms",
"avg_load_time_images_ms",
"video_images_load_time_ratio",
"avg_mse",
"avg_psnr",
"avg_ssim",
]
file_paths = []
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
benchmark_table = []
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
dataset = LeRobotDataset(repo_id)
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for duet in [
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
for unique_combination in itertools.product(*encoding_benchmarks.values())
]:
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for key, value in duet.items():
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
# Save intermediate results
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
now = dt.datetime.now()
csv_path = (
output_dir
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
)
benchmark_df.to_csv(csv_path, header=True, index=False)
file_paths.append(csv_path)
del benchmark_df
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_df.to_csv(concatenated_path, header=True, index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-dir",
type=Path,
default=Path("outputs/video_benchmark"),
help="Directory where the video benchmark outputs are written.",
)
parser.add_argument(
"--repo-ids",
type=str,
nargs="*",
default=[
"lerobot/pusht_image",
"lerobot/aloha_mobile_shrimp_image",
"lerobot/paris_street",
"lerobot/kitchen",
],
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
)
parser.add_argument(
"--vcodec",
type=str,
nargs="*",
default=["h264", "hevc", "libsvtav1"],
help="Video codecs to be tested",
)
parser.add_argument(
"--pix-fmt",
type=str,
nargs="*",
default=["yuv444p", "yuv420p"],
help="Pixel formats (chroma subsampling) to be tested",
)
parser.add_argument(
"--g",
type=parse_int_or_none,
nargs="*",
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
help="Group of pictures sizes to be tested.",
)
parser.add_argument(
"--crf",
type=parse_int_or_none,
nargs="*",
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
help="Constant rate factors to be tested.",
)
# parser.add_argument(
# "--fastdecode",
# type=int,
# nargs="*",
# default=[0, 1],
# help="Use the fastdecode tuning option. 0 disables it. "
# "For libx264 and libx265/hevc, only 1 is possible. "
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
# )
parser.add_argument(
"--timestamps-modes",
type=str,
nargs="*",
default=[
"1_frame",
"2_frames",
"2_frames_4_space",
"6_frames",
],
help="Timestamps scenarios to be tested.",
)
parser.add_argument(
"--backends",
type=str,
nargs="*",
default=["torchcodec", "pyav"],
help="Torchvision decoding backend to be tested.",
)
parser.add_argument(
"--num-samples",
type=int,
default=50,
help="Number of samples for each encoding x decoding config.",
)
parser.add_argument(
"--num-workers",
type=int,
default=10,
help="Number of processes for parallelized sample processing.",
)
parser.add_argument(
"--save-frames",
type=int,
default=0,
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
)
args = parser.parse_args()
main(**vars(args))
+1 -1
@@ -35,7 +35,7 @@ USER root
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-nvcc-12-6 cuda-cudart-dev-12-6 \
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
libvulkan1 vulkan-tools \
&& mkdir -p /usr/share/vulkan/icd.d \
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
+1 -1
@@ -18,7 +18,7 @@
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
# Configure the base image for CI with GPU access
ARG CUDA_VERSION=12.6.3
ARG CUDA_VERSION=12.8.1
ARG OS_VERSION=24.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
+9 -9
@@ -8,7 +8,7 @@
- local: il_robots
title: Imitation Learning for Robots
- local: bring_your_own_policies
title: Bring Your Own Policies
title: Adding a Policy
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
@@ -24,6 +24,12 @@
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: hardware_guide
title: Compute Hardware Guide
- local: torch_accelerators
title: PyTorch accelerators
title: "Compute & Hardware"
- sections:
- local: lerobot-dataset-v3
title: Using LeRobotDataset
@@ -31,10 +37,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: language_and_recipes
title: Language Columns and Recipes
- local: tools
title: Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
@@ -144,10 +148,6 @@
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
title: "Supported Hardware"
- sections:
- local: notebooks
title: Notebooks
+220 -81
@@ -1,60 +1,37 @@
# Bring Your Own Policies
# Adding a Policy
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
## Step 1: Create a Policy Package
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
### Package Structure
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
A note on conventions: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_custom_policy/
├── __init__.py
├── configuration_my_custom_policy.py
├── modeling_my_custom_policy.py
└── processor_my_custom_policy.py
```
---
### Package Configuration
## Anatomy of a policy
Set up your `pyproject.toml`:
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
### Configuration class
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
```python
# configuration_my_custom_policy.py
# configuration_my_policy.py
from dataclasses import dataclass, field
from lerobot.configs import PreTrainedConfig
from lerobot.optim import AdamWConfig
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("my_custom_policy")
@PreTrainedConfig.register_subclass("my_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
"""Configuration class for MyCustomPolicy.
class MyPolicyConfig(PreTrainedConfig):
"""Configuration class for MyPolicy.
Args:
n_obs_steps: Number of observation steps to use as input
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
raise ValueError("n_action_steps cannot exceed horizon")
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
"""Validate input/output feature compatibility.
Call this explicitly from your policy's __init__ — the base class does not.
"""
if not self.image_features:
raise ValueError("MyCustomPolicy requires at least one image feature.")
raise ValueError("MyPolicy requires at least one image feature.")
if self.action_feature is None:
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
raise ValueError("MyPolicy requires 'action' in output_features.")
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self):
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
return None
@property
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Relative timestep offsets for the action chunk the dataset loader returns.
"""
"""Relative timestep offsets for the action chunk the dataset loader returns."""
return list(range(self.horizon))
@property
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
return None
```
## Step 3: Implement the Policy Class
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
### Policy class
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
```python
# modeling_my_custom_policy.py
# modeling_my_policy.py
import torch
import torch.nn as nn
from typing import Any
from lerobot.policies import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .configuration_my_policy import MyPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
name = "my_custom_policy"
class MyPolicy(PreTrainedPolicy):
config_class = MyPolicyConfig # must match the string in @register_subclass
name = "my_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
config.validate_features() # not called automatically by the base class
self.config = config
self.model = ... # your nn.Module here
def reset(self):
"""Reset episode state."""
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
...
def get_optim_params(self) -> dict:
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
...
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return a single action for the current timestep (called at inference)."""
"""Return a single action for the current timestep (called every step at inference)."""
...
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
"""Compute the training loss.
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
logging-friendly Python natives (no tensors with gradients).
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
timesteps padded because the episode ended before `horizon` steps, you
timesteps padded because the episode ended before `horizon` steps; you
can exclude those from your loss.
"""
actions = batch[ACTION]
action_is_pad = batch.get("action_is_pad")
...
return {"loss": ...}
return loss, {"some_loss_component": some_loss_component.item()}
```
## Step 4: Add Data Processors
The methods called by the train/eval loops:
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
| Method | Used by | What it does |
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
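For illustration, a hedged sketch of what such a flat batch looks like, with the keys spelled out and shapes assumed:
```python
batch = {
    "observation.state": state_tensor,         # (B, state_dim), keyed via OBS_STATE
    "observation.images.front": image_tensor,  # (B, C, H, W), keyed via OBS_IMAGES + camera name
    "action": action_tensor,                   # (B, horizon, action_dim), keyed via ACTION
    "action_is_pad": pad_mask,                 # (B, horizon) bool
}
```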
### Processor functions
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
```python
# processor_my_policy.py
from typing import Any

import torch

from lerobot.processor import PolicyAction, PolicyProcessorPipeline


def make_my_policy_pre_post_processors(
    config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    ...  # build the normalization / device steps here (elided in this excerpt)
    return preprocessor, postprocessor
```
**Important (function naming):** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
---
## Path A: Out-of-tree plugin
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR is required, you own the release cycle, and you can publish to PyPI under your own namespace.
### Package structure
Create a package whose name starts with the `lerobot_policy_` prefix (this prefix is required) followed by your policy name:
```bash
lerobot_policy_my_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_policy/
├── __init__.py
├── configuration_my_policy.py
├── modeling_my_policy.py
└── processor_my_policy.py
```
### `pyproject.toml`
```toml
[project]
name = "lerobot_policy_my_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
[build-system]
# pick your build backend, e.g. hatchling:
requires = ["hatchling"]
build-backend = "hatchling.build"
```
### Package `__init__.py`
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
```python
# __init__.py
try:
    import lerobot  # noqa: F401
except ImportError:
    raise ImportError(
        "lerobot is not installed. Please install lerobot to use this policy package."
    )

from .configuration_my_policy import MyPolicyConfig
from .modeling_my_policy import MyPolicy
from .processor_my_policy import make_my_policy_pre_post_processors

__all__ = [
    "MyPolicyConfig",
    "MyPolicy",
    "make_my_policy_pre_post_processors",
]
```
### Install and use
```bash
cd lerobot_policy_my_policy
pip install -e .
# Or install from PyPI if published
pip install lerobot_policy_my_policy
```
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
```bash
lerobot-train \
--policy.type my_policy \
--env.type pusht \
--steps 200000
```
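Evaluation picks the plugin up the same way; point `--policy.path` at a checkpoint you trained with it (the repo id below is a placeholder):

```bash
lerobot-eval \
  --policy.path=<USER>/my_policy_pusht \
  --env.type=pusht \
  --eval.n_episodes=10
```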
---
## Path B: Contributing in-tree
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
### In-tree layout
```
src/lerobot/policies/my_policy/
├── __init__.py # re-exports config + modeling + processor factory
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
├── processor_my_policy.py # make_my_policy_pre_post_processors
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
```
Two notes:
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
### Wiring
Three places need to know about your policy. All by name.
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
Mirror an existing policy that's structurally similar to yours; the diff is small.
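As a sketch of the shape (not the exact current code; mirror `factory.py` as it stands when you edit it):

```python
# factory.py (sketch): lazy import so `import lerobot` stays fast
def get_policy_class(name: str) -> type:
    if name == "my_policy":
        from lerobot.policies.my_policy.modeling_my_policy import MyPolicy

        return MyPolicy
    ...  # existing branches for act, diffusion, smolvla, ...
```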
### Heavy / optional dependencies
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
```python
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
else:
DDIMScheduler = None # keeps the symbol bindable at import time
class DiffusionPolicy(PreTrainedPolicy):
def __init__(self, config):
require_package("diffusers", extra="diffusion")
super().__init__(config)
...
```
This way:
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
- Type checkers see the real symbol.
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
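The extra itself is a one-liner (a sketch; the dependency name and pin are placeholders), plus an entry for `my_policy` in the `all` extra following the pattern already in the file:

```toml
[project.optional-dependencies]
# hypothetical extra for MyPolicy's heavy backbone
my_policy = ["some-heavy-backbone-sdk>=1.0"]
```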
### Benchmarks and a published checkpoint
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
```markdown
## Results
Evaluated on LIBERO with `lerobot/<policy>_libero`:
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 87.5% | 50 |
| libero_object | 93.0% | 50 |
| libero_goal | 81.5% | 50 |
| libero_10 | 62.0% | 50 |
| **average** | **81.0%** | 200 |
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
```
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
### PR checklist
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
- [ ] `tests/policies/` updated; backward-compat artifact committed and policy-specific tests added.
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
---
## Examples and community contributions
Check out these example policy implementations:
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow): Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
Share your policy implementations with the community! 🤗
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
@@ -0,0 +1,277 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_ (2022).
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (DataFrame index) |
| ------------- | ------------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
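To peek at this mapping directly, a sketch using plain pandas (the dataset API below exposes the same table as `dataset.meta.subtasks`):

```python
import pandas as pd

# Read the subtask mapping straight off disk.
subtasks = pd.read_parquet("my-dataset/meta/subtasks.parquet")
print(subtasks.head())
```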
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation, click "Push to Hub" to upload your annotated dataset. You can also run the annotation Space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate).
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
import torch.nn as nn

class HierarchicalPolicy(nn.Module):
    def __init__(self, num_subtasks, obs_dim=64, hidden_dim=256, action_dim=7):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden_dim)  # stand-in for a real observation encoder
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.subtask_head = nn.Linear(hidden_dim, num_subtasks)

    def forward(self, observations):
        features = self.encoder(observations)
        actions = self.action_head(features)
        subtask_logits = self.subtask_head(features)
        return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
import torch.nn as nn

class SARMRewardModel(nn.Module):
    def __init__(self, num_stages, obs_dim=64, hidden_dim=256):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden_dim)  # stand-in encoder
        self.stage_classifier = nn.Linear(hidden_dim, num_stages)
        self.progress_regressor = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())

    def forward(self, observations):
        features = self.encoder(observations)
        stage_logits = self.stage_classifier(features)
        progress = self.progress_regressor(features)
        return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations, subtask_names):
    for t, obs in enumerate(observations):
        action, subtask_logits = model(obs)
        predicted_subtask = subtask_names[subtask_logits.argmax()]
        print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
@@ -0,0 +1,98 @@
# Compute HW Guide for LeRobot Training
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
## Memory by policy group
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
| Light BC   | `act`, `vqbet`, `tdmpc`                     |                 ~2–6 GB | Laptop GPU (RTX 3060), L4, A10G   |
| Diffusion  | `diffusion`, `multi_task_dit`               |                ~8–14 GB | RTX 4070+ / L4 / A10G             |
| Small VLA  | `smolvla`                                   |               ~10–16 GB | RTX 4080+ / L4 / A10G             |
| Large VLA  | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` |               ~24–40 GB | A100 40 GB+ (24 GB tight at BS 1) |
| Multimodal | `groot`, `eo1`                              |               ~24–40 GB | A100 40 GB+                       |
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
Memory-bound? Drop the batch size (memory scales ~linearly), use gradient accumulation to recover the effective batch size, or for SmolVLA leave `freeze_vision_encoder=True`.
## Training time
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
```text
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
total_steps = epochs × steps_per_epoch
wall_clock ≈ total_steps × per_step_time
```
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
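A worked example with assumed numbers (one GPU, batch size 8, ~0.17 s/step, close to the `diffusion` anchor in the multi-GPU table further down):

```python
import math

total_frames = 50 * 30 * 30  # 50 episodes × 30 fps × 30 s = 45,000 frames
steps_per_epoch = math.ceil(total_frames / (1 * 8))  # 1 GPU, batch size 8 → 5,625
total_steps = 5 * steps_per_epoch  # 5 epochs → 28,125 steps
wall_clock_h = total_steps * 0.17 / 3600  # ≈ 1.3 h at ~0.17 s/step
print(f"{total_steps} steps ≈ {wall_clock_h:.1f} h")
```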
### Common scenarios
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
| Setup | Policy | Batch | Wall-clock |
| ------------------------------------ | -------------- | ----- | ---------: |
| Single RTX 4090 / RTX 3090 (24 GB)   | `act`          | 8     | ~30–60 min |
| Single RTX 4090 / RTX 3090 (24 GB)   | `diffusion`    | 8     |     ~2–4 h |
| Single L4 / A10G (24 GB)             | `act`          | 8     |     ~1–2 h |
| Single L4 / A10G (24 GB)             | `smolvla`      | 4     |     ~3–6 h |
| Single A100 40 GB                    | `smolvla`      | 16    |     ~1–2 h |
| Single A100 40 GB                    | `pi0` / `pi05` | 4     |     ~4–8 h |
| 4× H100 80 GB cluster (`accelerate`) | `diffusion`    | 32    | ~30–60 min |
| 4× H100 80 GB cluster (`accelerate`) | `smolvla`      | 32    |     ~1–2 h |
| Apple Silicon M1/M2/M3 Max (MPS)     | `act`          | 4     |    ~6–14 h |
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
### Multi-GPU matters a lot
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
### Schedule and checkpoints
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule by setting `--policy.scheduler_decay_steps` to roughly the same value as `--steps`; otherwise the LR stays near its peak and never decays. The same goes for `--save_freq`.
## Where to run
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` tier column is relative; check current pricing on the provider you actually use.
| Class | VRAM | Tier | Comfortable for |
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
| L4 / A10G (cloud) | 24 GB | `$$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
### Hugging Face Jobs
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
```bash
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
bash -c "nvidia-smi && lerobot-train \
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
```
Notes:
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
### Understanding Configuration
The training process begins with proper configuration for the HIL-SERL environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
<!-- prettier-ignore-start -->
```python
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
class ObservationConfig:
    add_joint_velocity_to_observation: bool = False  # Add joint velocities to state
    add_current_to_observation: bool = False  # Add motor currents to state
    add_ee_pose_to_observation: bool = False  # Add end-effector pose to state
    display_cameras: bool = False  # Display camera feeds during execution
class ImagePreprocessingConfig:
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
```
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
**Example Configuration**
```json
"end_effector_bounds": {
"max": [0.24, 0.20, 0.10],
"min": [0.16, -0.08, 0.03]
{
"env": {
"processor": {
"inverse_kinematics": {
"end_effector_bounds": {
"max": [0.24, 0.2, 0.1],
"min": [0.16, -0.08, 0.03]
}
}
}
}
}
```
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
HIL-SERL learns actions in the end-effector space of the robot, so teleoperation controls the end-effector's x, y, z displacements.
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
<!-- prettier-ignore-start -->
```python
class InverseKinematicsConfig:
    """Configuration for inverse kinematics processing."""

    urdf_path: str | None = None
    target_frame_name: str | None = None
    # bounds for the end-effector in x,y,z direction
    end_effector_bounds: dict[str, list[float]] | None = None
    # maximum step size for the end-effector in x,y,z direction
    end_effector_step_sizes: dict[str, float] | None = None


class HILSerlProcessorConfig:
    ...
    # maximum gripper position that the gripper will be open at
    max_gripper_pos: float | None = 100.0
```
<!-- prettier-ignore-end -->
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
**Collecting a Dataset for the reward classifier**
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
To collect a dataset, you need to modify some parameters in the environment configuration, based on `HILSerlRobotEnvConfig`.
@@ -658,7 +661,7 @@ Example configuration section for data collection:
},
"dataset": {
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes_to_record": 20,
"replay_episode": null,
@@ -671,7 +674,7 @@ Example configuration section for data collection:
**Reward Classifier Configuration**
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
- **model_name**: Base model architecture (we mainly use `"helper2424/resnet10"`)
- **model_type**: `"cnn"` or `"transformer"`
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
"repo_id": "hf_username/dataset_name",
"root": null
},
"policy": {
"reward_model": {
"type": "reward_classifier",
"model_name": "helper2424/resnet10",
"model_type": "cnn",
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
"dropout_rate": 0.1,
"learning_rate": 1e-4,
"device": "cuda",
"use_amp": true,
"input_features": {
"observation.images.front": {
"type": "VISUAL",
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
**Configuration Setup**
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
1. Configure the policy settings (`type="sac"`, `device`, etc.)
2. Set `dataset` to your cropped dataset
3. Configure environment settings with crop parameters
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
3. Set `dataset` to your cropped dataset
4. Configure environment settings with crop parameters
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
**Starting the Learner**
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
Some configuration values have a disproportionate impact on training stability and speed:
- **`temperature_init`** (`algorithm.temperature_init`): initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`): interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1–2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
- **`storage_device`** (`policy.storage_device`): device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
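In the training config JSON, these knobs sit roughly here (a sketch; the nesting follows the config paths given in the bullets above):

```json
{
  "policy": {
    "storage_device": "cuda",
    "actor_learner_config": { "policy_parameters_push_frequency": 2 }
  },
  "algorithm": { "temperature_init": 0.01 }
}
```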
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
### PyTorch CUDA variant (Linux only)
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
<!-- prettier-ignore-start -->
<hfoptions id="cuda_variant">
<hfoption id="uv-source">
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
To override for a different CUDA variant:
```bash
uv pip install --force-reinstall torch torchvision \
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
```
</hfoption>
<hfoption id="pip-conda">
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
The PyPI default torch wheel is currently a cu130-bundled Linux wheel with a driver floor of **580.65**.
To pick a specific CUDA variant:
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
```bash
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
pip install -e ".[all]" # source
# — or —
pip install lerobot # from PyPI
```
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
```bash
uv pip install --torch-backend cu128 lerobot
```
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
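Either way, you can check which variant you ended up with:

```bash
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```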
### Troubleshooting
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and the FFmpeg libraries.
@@ -1,147 +0,0 @@
# Language columns and recipes
Most LeRobot datasets ship with a single `task` string per episode — fine for
short, single-instruction skills, but not enough for the longer-horizon,
multi-modal robot policies the field is moving toward (high-level planning,
memory, interjections, VQA, tool use). To support those policies without
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
language columns and a small recipe layer that turns those rows into
chat-style training samples on the fly.
The design splits cleanly into three layers:
1. **Data in the dataset** — language annotations stored next to frames in
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
and `language_events`). Datasets without these columns keep their existing
behavior.
2. **Recipe** — a YAML file that declares which annotation rows to bind and
how to lay them out as chat turns (`role`, `content`, optional images,
optional tool calls). Recipes are pure config; no Python required to add a
new one.
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
recipe against the per-frame annotations and emits HF-style `messages` plus
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
that policy processors consume.
This page describes each layer in turn.
## Layer 1 — language columns in the dataset
The two optional columns live next to frame data in
`data/chunk-*/file-*.parquet`:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):
```text
role: string
content: string | null
style: string | null
timestamp: float64 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```
The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
the matching `observation.images.*` feature key. Rows of every other style —
including `motion`, which describes robot-frame primitives in joint / Cartesian
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
enforce this via `validate_camera_field(style, camera)`.
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
### Architecture
The language stack itself has three internal modules backing layer 1:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
### Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
### View-dependent resolution
For view-dependent styles (`vqa` and `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.
```yaml
ask_vqa_top:
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- { type: image, feature: observation.images.top }
- { type: text, text: "${vqa_query}" }
- {
role: assistant,
content: "${vqa}",
stream: high_level,
target: true,
if_present: vqa,
}
```
Add one such sub-recipe per camera the dataset records.
## Layer 2 — recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
declare which annotation rows to pull (via `bindings`) and how to compose them
into chat turns (`messages`).
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
time, exactly one branch is selected deterministically from the sample index,
so different frames train different objectives (e.g. memory updates vs.
low-level execution vs. VQA) without any Python wiring.
## Layer 3 — training format
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
@@ -28,13 +28,15 @@ lerobot-train \
--steps=100000 \
--batch_size=32 \
--peft.method_type=LORA \
--peft.r=64 \
--peft.lora_alpha=64
```
Note the `--peft.method_type` parameter that lets you select which PEFT method to use. Here we use
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adaptation), which is probably the most
popular fine-tuning method to date. Low-rank adaptation means that we only fine-tune a matrix with comparably low rank
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank,
the closer you get to full fine-tuning.
There are more complex methods that have more parameters. These are not yet supported; feel free to raise an issue.
@@ -1,210 +0,0 @@
# Tools
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
emit structured invocations like `say(text="OK, starting now")` that the
runtime dispatches to a real implementation (TTS, controller, logger, …).
This page covers:
1. Where the tool catalog lives.
2. How the annotation pipeline produces tool-call atoms.
3. How to add your own tool.
## Where tools are declared
Two layers.
**The catalog** — a list of OpenAI-style function schemas — lives at
`meta/info.json["tools"]` on each dataset. Example:
```json
{
"features": { "...": "..." },
"tools": [
{
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak."
}
},
"required": ["text"]
}
}
}
]
}
```
Read it via the dataset metadata accessor:
```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
tools = meta.tools # list[dict] — OpenAI tool schemas
```
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
single-entry list with the canonical `say` schema. So unannotated
datasets and chat-template consumers keep working without any
configuration:
```python
prompt_str = tokenizer.apply_chat_template(
sample["messages"],
tools=meta.tools, # works either way
add_generation_prompt=False,
tokenize=False,
)
```
**The implementations** — runnable Python — will live under
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
not part of the catalog layer described here; today this layer ships
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
## Per-row tool _invocations_
The catalog above describes _what can be called_. The actual _call_ — the
function name plus the argument values — is stored per-row, on the
assistant atoms in `language_events`:
```python
{
"role": "assistant",
"content": null,
"style": null,
"timestamp": 12.4,
"camera": null,
"tool_calls": [
{ "type": "function",
"function": { "name": "say", "arguments": { "text": "On it." } } }
]
}
```
Recipes splice these into rendered messages via `tool_calls_from`:
```yaml
user_interjection_response:
bindings:
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- { role: user, content: "${task}", stream: high_level }
- {
role: assistant,
content: "${current_plan}",
stream: high_level,
target: true,
tool_calls_from: speech,
}
```
The model's training target is one assistant turn that carries both the
plan text _and_ the `say` tool call. At inference, the runtime parses
the generated text back into structured `tool_calls` and dispatches to
the matching implementation.
## How to add your own tool
> **Note:** Steps 2 and 3 below describe the runtime layer
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
> `get_tools(meta)`) which is not part of the catalog layer shipped
> today — those modules don't yet exist in the tree. Step 1 alone is
> enough to make the tool visible to the chat template via
> `meta.tools` so the model can learn to _generate_ the call;
> executing the call at inference requires the runtime layer.
Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control
loop.
### Step 1 — declare the schema
Add an entry under `meta/info.json["tools"]`. Either edit the file
directly on disk _before_ running the annotation pipeline (it'll be
preserved) or hand it to `lerobot-annotate` via a config flag.
```json
{
"tools": [
{ "type": "function", "function": { "name": "say", "...": "..." } },
{
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a high-resolution still image for the user.",
"parameters": {
"type": "object",
"properties": {
"label": {
"type": "string",
"description": "Short label for the saved image."
}
},
"required": ["label"]
}
}
}
]
}
```
The schema follows OpenAI's function-calling convention exactly, so the
chat template can render it natively.
### Step 2 — implement the call
Create `src/lerobot/tools/record_observation.py`:
```python
from .base import Tool
from typing import Any
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
class RecordObservationTool:
name = "record_observation"
schema = RECORD_OBSERVATION_SCHEMA
def __init__(self, schema: dict | None = None, output_dir: str = "."):
self.output_dir = output_dir
def call(self, arguments: dict) -> str:
label = arguments["label"]
# ... save the latest camera frame to <output_dir>/<label>.png ...
return f"saved {label}.png"
```
One file per tool keeps dependencies isolated — `record_observation`
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
only the tools they need avoid heavy transitive deps.
### Step 3 — register it
Add to `src/lerobot/tools/registry.py`:
```python
from .record_observation import RecordObservationTool
TOOL_REGISTRY["record_observation"] = RecordObservationTool
```
That's it. At runtime `get_tools(meta)` looks up each schema in
`meta.tools`, instantiates the matching registered class, and returns
a name → instance dict the dispatcher can route into.
If you want to use a tool _without_ writing an implementation (e.g. for
training-time chat-template formatting only), step 1 alone is enough —
the model still learns to _generate_ the call. Steps 2 and 3 are only
needed to actually _execute_ it at inference.
@@ -0,0 +1,136 @@
# OMX Follower — Cube Pick And Place Example
This is an example of what is possible with LeRobot on a physical setup.
It is a work in progress, used internally at LeRobot, and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
## Hardware
| Component | Value |
| --------- | ------------------------------------ |
| Robot | OMX Follower |
| Cameras | 2× OpenCV cameras (wrist + top-down) |
## Scripts
| Script | Purpose |
| ---------------------- | --------------------------------------------------------------- |
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
## Setup
Make sure you have LeRobot installed in your environment (see [the installation guide](https://huggingface.co/docs/lerobot/installation)).
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
```bash
export ROBOT_PORT=/dev/ttyACM0
export TELEOP_PORT=/dev/ttyACM1
export HF_USERNAME=<your_hf_username>
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
```
## Step 1 — Collect Data
```bash
lerobot-record \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--teleop.type=omx_leader \
--teleop.port=$TELEOP_PORT \
--teleop.id=omx_leader \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--dataset.root=data/omx_pickandplace \
--dataset.num_episodes=50 \
--dataset.single_task="Pick the cube and place it in the blue square" \
--dataset.streaming_encoding=true \
--dataset.push_to_hub=true
```
### Bonus Auto-Collect script
/!\ This script is specific to our setup and the task of picking and placing a cube; it is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
```bash
python -m examples.omx.record_grab \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--dataset.root=data/omx_pickandplace \
--dataset.num_episodes=50 \
--dataset.single_task="Pick the cube and place it in the blue square" \
--dataset.streaming_encoding=true \
--dataset.push_to_hub=true
```
Each episode:
1. The arm grabs the cube from the center of the workspace and places it at a random position.
2. The arm returns to HOME.
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
## Step 2 — Train
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
```bash
lerobot-train \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--policy.type=act \
--output_dir=outputs/train/omx_pickandplace_act \
--policy.device=cuda \
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
--steps=20000 \
--wandb.enable=true
```
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
## Step 3 — Rollout
Use the `lerobot-rollout` CLI with the base strategy:
```bash
lerobot-rollout \
--strategy.type=base \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--policy.path=$HF_USERNAME/omx_pickandplace_act
```
For continuous recording with automatic upload (sentry mode):
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=10 \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--policy.path=$HF_USERNAME/omx_pickandplace_act \
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act
```
## Environment Reset Utility
These scripts are specific to this particular physical setup: they execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
`reset_environment.py` can be run standalone to prepare the workspace:
```bash
# Grab cube + place it at a random position on the left side
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
```
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
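A minimal sketch of using those helpers from a custom script (assuming a robot built and connected the way `record_grab.py` does it; the config plumbing is up to your script):

```python
from examples.omx.reset_environment import grab_cube, place_cube
from lerobot.robots import make_robot_from_config

robot = make_robot_from_config(robot_config)  # build robot_config as record_grab.py does
robot.connect()
grab_cube(robot)   # pick the cube up from the workspace
place_cube(robot)  # place it back down at a random position
robot.disconnect()
```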
@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""
Auto-record grab episodes for the OMX robot arm.
Each episode cycle:
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
2. HOME — return arm to home with gripper open
3. record_grab — execute a targeted grab to the stored position while recording
observations + actions to a LeRobotDataset
Usage (run from repo root):
python -m examples.omx.record_grab \\
--robot.type=omx_follower \\
--robot.port=/dev/ttyACM0 \\
--robot.id=omx_follower \\
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
--dataset.repo_id=<hf_username>/<dataset_name> \\
--dataset.root=data/omx_grab \\
--dataset.num_episodes=50 \\
--dataset.single_task="Grab the cube" \\
--dataset.streaming_encoding=true
"""
import logging
from dataclasses import dataclass
from pprint import pformat
import numpy as np
from lerobot.cameras import CameraConfig # noqa: F401
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
LeRobotDataset,
VideoEncodingManager,
aggregate_pipeline_dataset_features,
create_initial_features,
)
from lerobot.processor import make_default_processors
from lerobot.robots import RobotConfig, make_robot_from_config
from lerobot.robots.omx_follower import OmxFollower
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from .reset_environment import (
APPROACH_SPEED,
GRIPPER_CLOSE_POS,
HOME_POSE,
PUSH_END_ELBOW_FLEX,
PUSH_END_SHOULDER_LIFT,
PUSH_START_ELBOW_FLEX,
PUSH_START_SHOULDER_LIFT,
array_to_pose,
grab_cube,
horizontal_wrist_flex,
move_to_pose,
place_cube,
pose_to_array,
)
# ── Grab-episode motion parameters ────────────────────────────────────────────
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
GRAB_RAISE_SL_OFFSET = 20.0
GRAB_LOWER_SPEED = 20.0
RECORD_SPEED = 30.0
# Pose the arm travels to after closing the gripper (cube held).
GRAB_CARRY_POSE = {
"shoulder_pan.pos": -23.0,
"shoulder_lift.pos": 5.0,
"elbow_flex.pos": 18.0,
"wrist_flex.pos": -14.0,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
}
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
# Cube-approach and carry poses are never jittered to preserve precision.
_JITTER_LIMITS: dict[str, float] = {
"shoulder_pan.pos": 5.0,
"shoulder_lift.pos": 4.0,
"elbow_flex.pos": 4.0,
"wrist_flex.pos": 3.0,
"wrist_roll.pos": 2.0,
"gripper.pos": 0.0,
}
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
"""Return a copy of pose with independent per-joint random perturbations."""
return {
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
}
def _random_stuck_pose(rng: np.random.Generator) -> dict:
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
table-safe envelope across the full sl range:
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
sl= 0 → ef ∈ [-25, 25] (mid reach)
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
wrist_flex is randomly offset from the horizontal value.
"""
pan = float(rng.uniform(-5.0, 35.0))
sl = float(rng.uniform(-50.0, 30.0))
if sl <= 0.0:
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
ef_lo = alpha * -25.0 # 0 → -25
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
else:
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
ef = float(rng.uniform(ef_lo, ef_hi))
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf,
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
"gripper.pos": GRIPPER_CLOSE_POS,
}
logger = logging.getLogger(__name__)
@dataclass
class OmxRecordGrabConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Resume recording on an existing dataset.
resume: bool = False
# Fraction of episodes that start from a random stuck pose (gripper closed) to
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
recovery_prob: float = 0.5
def record_episode_spline(
robot: OmxFollower,
waypoints: list[dict],
speeds: list[float],
dataset: LeRobotDataset,
task: str,
) -> None:
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
Segment durations are parameterized from the maximum absolute joint delta
between consecutive waypoints divided by the requested segment speed,
producing non-uniform timing in joint space. Interior tangents are derived
from the adjacent per-segment velocities, with clamped (zero-velocity)
endpoints so the arm starts and stops smoothly. Each segment is cubic
Hermite, giving C1 continuity at every waypoint.
"""
pts = [pose_to_array(w) for w in waypoints]
n = len(pts)
# Steps and duration per segment
n_steps_list = []
timestamps = []
for i in range(n - 1):
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
n_steps_list.append(ns)
timestamps.append(ns / dataset.fps)
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
vels = [np.zeros_like(pts[0])]
for i in range(1, n - 1):
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
vels.append(0.5 * (v_prev + v_next))
vels.append(np.zeros_like(pts[0]))
dt = 1.0 / dataset.fps
for seg in range(n - 1):
ns = n_steps_list[seg]
if ns == 0:
continue
p0, p1 = pts[seg], pts[seg + 1]
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
m0 = vels[seg] * timestamps[seg]
m1 = vels[seg + 1] * timestamps[seg]
for step in range(1, ns + 1):
t = step / ns
h00 = 2 * t**3 - 3 * t**2 + 1
h10 = t**3 - 2 * t**2 + t
h01 = -2 * t**3 + 3 * t**2
h11 = t**3 - t**2
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
action = array_to_pose(commanded)
robot.send_action(action)
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
dataset.add_frame({**obs_frame, **action_frame, "task": task})
precise_sleep(dt)
def record_grab_episode(
robot: OmxFollower,
dataset: LeRobotDataset,
pan: float,
t: float,
task: str,
recovery_start: bool = False,
) -> None:
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
Normal sequence (initial HOME move is NOT recorded):
HOME → raised approach above cube → lower → close gripper
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
(gripper closed) without recording, then recording begins from there:
stuck_pose → raised approach above cube → [normal grab sequence from there]
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
"""
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
sl_raised = sl - GRAB_RAISE_SL_OFFSET
wf_horizontal = horizontal_wrist_flex(sl, ef)
rng = np.random.default_rng()
if recovery_start:
stuck_pose = _random_stuck_pose(rng)
logger.info(f"Recovery start: {stuck_pose}")
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
first_waypoints = [stuck_pose]
first_speeds = []
else:
jittery_start = _jitter_pose(HOME_POSE, rng)
move_to_pose(robot, jittery_start, APPROACH_SPEED)
first_waypoints = [jittery_start]
first_speeds = []
waypoints = first_waypoints + [
{ # raised approach: arm above cube
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl_raised,
"elbow_flex.pos": ef,
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{ # lower onto cube — no jitter: precision needed
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf_horizontal,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{ # close gripper — no jitter: precision needed
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf_horizontal,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
_jitter_pose(
{ # raise with cube
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl_raised,
"elbow_flex.pos": ef,
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
rng,
),
_jitter_pose(
{ # retract: fold arm toward HOME before sweeping to carry zone
"shoulder_pan.pos": pan * 0.25,
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
"wrist_flex.pos": 0.0,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
rng,
),
GRAB_CARRY_POSE, # no jitter: target drop zone
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
HOME_POSE,
]
speeds = first_speeds + [
RECORD_SPEED, # (HOME →) raised approach
GRAB_LOWER_SPEED, # raised approach → lower
GRAB_LOWER_SPEED, # lower → close gripper
RECORD_SPEED, # close gripper → raise
RECORD_SPEED, # raise → retract
RECORD_SPEED, # retract → carry pose
RECORD_SPEED, # carry pose → drop
RECORD_SPEED, # drop → HOME
]
record_episode_spline(robot, waypoints, speeds, dataset, task)
# Dwell at HOME for ~0.5 s before next episode
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
dt = 1.0 / dataset.fps
for _ in range(int(dataset.fps * 0.5)):
robot.send_action(HOME_POSE)
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
dataset.add_frame({**obs_frame, **home_action, "task": task})
precise_sleep(dt)
@parser.wrap()
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger.info(pformat(cfg))
robot = make_robot_from_config(cfg.robot)
use_videos = cfg.dataset.video
teleop_action_processor, _, robot_obs_processor = make_default_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=robot.action_features),
use_videos=use_videos,
),
aggregate_pipeline_dataset_features(
pipeline=robot_obs_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=use_videos,
),
)
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
dataset = None
try:
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
streaming_encoding=cfg.dataset.streaming_encoding,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
else:
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=use_videos,
streaming_encoding=cfg.dataset.streaming_encoding,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
robot.connect(calibrate=True)
rng = np.random.default_rng()
with VideoEncodingManager(dataset):
for episode_idx in range(cfg.dataset.num_episodes):
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
logger.info("Step 1: grabbing and placing cube...")
grab_cube(robot)
pan, t = place_cube(robot)
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
record_grab_episode(
robot,
dataset,
pan,
t,
cfg.dataset.single_task,
recovery_start=recovery_start,
)
dataset.save_episode()
logger.info(f"Episode {episode_idx + 1} saved.")
finally:
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
if __name__ == "__main__":
record_grab()
+267
@@ -0,0 +1,267 @@
#!/usr/bin/env python3
"""
Auto-reset and cube-grab utility for the OMX robot arm.
Provides:
- grab_cube(robot): sweep workspace, center cube, close gripper
- place_cube(robot): carry cube to a random position, release
Standalone usage (run from repo root):
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
To read current joint values for calibration, add after robot.connect():
obs = robot.get_observation()
print({k: round(obs[k], 1) for k in JOINT_NAMES})
robot.disconnect(); raise SystemExit
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
Linear interpolation preserves this constraint between any two poses that satisfy it.
"""
import argparse
import logging
import numpy as np
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
from lerobot.robots.robot import Robot
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
# ── Poses ─────────────────────────────────────────────────────────────────────
HOME_POSE = {
"shoulder_pan.pos": 0.0,
"shoulder_lift.pos": -50.0,
"elbow_flex.pos": 50.0,
"wrist_flex.pos": 0.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
}
SWEEP_WAYPOINTS = [
{
"shoulder_pan.pos": -60.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -60.0,
"wrist_flex.pos": -20.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{
"shoulder_pan.pos": -30.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -60.0,
"wrist_flex.pos": -5.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{
"shoulder_pan.pos": 20.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -55.0,
"wrist_flex.pos": -5.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
]
# ── Motion parameters ─────────────────────────────────────────────────────────
CONTROL_HZ = 30
APPROACH_SPEED = 50.0
SWEEP_SPEED = 40.0
# ── Grab-sequence parameters ──────────────────────────────────────────────────
GRAB_PAN = 0.0
SWEEP_LEFT_PAN = -60.0
SWEEP_RIGHT_PAN = 60.0
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
SWEEP_END_PAN_RANGE = (15.0, 20.0)
SWEEP_LOW_SHOULDER_LIFT = 50.0
SWEEP_LOW_ELBOW_FLEX_START = -60.0
SWEEP_LOW_ELBOW_FLEX_END = -55.0
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
PUSH_START_SHOULDER_LIFT = 0.0
PUSH_START_ELBOW_FLEX = 45.0
PUSH_END_SHOULDER_LIFT = 50.0
PUSH_END_ELBOW_FLEX = -50.0
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
# Does not affect the grab-target interpolation in record_grab.py.
PUSH_RAISE_OFFSET = 5.0
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
GRIPPER_CLOSE_POS = 50.0
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
JOINT_NAMES = [
"shoulder_pan.pos",
"shoulder_lift.pos",
"elbow_flex.pos",
"wrist_flex.pos",
"wrist_roll.pos",
"gripper.pos",
]
# ── Helpers ───────────────────────────────────────────────────────────────────
def pose_to_array(pose: dict) -> np.ndarray:
return np.array([pose[k] for k in JOINT_NAMES])
def array_to_pose(arr: np.ndarray) -> dict:
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
sl = SWEEP_LOW_SHOULDER_LIFT
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": elbow_flex,
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
}
def _high_sweep_pose(pan: float) -> dict:
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": shoulder_lift,
"elbow_flex.pos": elbow_flex,
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
"wrist_roll.pos": 0.0,
"gripper.pos": gripper,
}
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
"""Interpolate from current position to target at the given speed (units/s)."""
obs = robot.get_observation()
current = np.array([obs[k] for k in JOINT_NAMES])
goal = pose_to_array(target)
max_distance = float(np.max(np.abs(goal - current)))
if max_distance < 0.5:
return
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
dt = 1.0 / CONTROL_HZ
for step in range(1, n_steps + 1):
t = step / n_steps
robot.send_action(array_to_pose(current + t * (goal - current)))
precise_sleep(dt)
# ── Sequences ─────────────────────────────────────────────────────────────────
def grab_cube(robot: Robot) -> None:
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
for pan, end_pan in [
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
]:
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
move_to_pose(
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
)
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
logger.info("Extending to push cube into gripper...")
move_to_pose(
robot,
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
APPROACH_SPEED,
)
move_to_pose(
robot,
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
SWEEP_SPEED,
)
logger.info("Closing gripper...")
move_to_pose(
robot,
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
APPROACH_SPEED,
)
logger.info("Grab complete.")
def place_cube(robot: Robot) -> tuple[float, float]:
"""Carry the cube (gripper closed) to a random position on the left side, then release.
Returns:
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
"""
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
t = float(np.random.uniform(*PLACE_REACH_RANGE))
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
move_to_pose(
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
)
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
logger.info("Place complete.")
return pan, t
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
parser.add_argument("--port", default="/dev/ttyACM1")
parser.add_argument("--robot_id", default="omx_follower")
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
robot.connect(calibrate=True)
try:
if args.mode == "grab":
grab_cube(robot)
elif args.mode == "grab_and_place":
grab_cube(robot)
place_cube(robot)
finally:
robot.disconnect()
if __name__ == "__main__":
main()
+24 -21
@@ -4,13 +4,13 @@ from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rewards.classifier.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -28,7 +28,7 @@ def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
policy_learner: GaussianActorPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
@@ -40,8 +40,9 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers_and_scheduler()
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -83,24 +84,26 @@ def run_learner(
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
def batch_iter(b=batch):
while True:
yield b
optimizer.zero_grad()
loss.backward()
optimizer.step()
stats = algorithm.update(batch_iter())
training_step += 1
if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
@@ -113,7 +116,7 @@ def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
policy_actor: GaussianActorPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
@@ -144,15 +147,15 @@ def run_actor(
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy
# Get action from policy (returns full action: continuous + discrete)
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action_tensor = policy_actor.select_action(policy_obs)
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
@@ -261,14 +264,14 @@ def main():
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
policy_cfg = GaussianActorConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
policy_actor = GaussianActorPolicy(policy_cfg)
policy_learner = GaussianActorPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
+30 -5
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Core ML
"torch>=2.7,<2.11.0",
"torchvision>=0.22.0,<0.26.0",
"torch>=2.7,<2.12.0",
"torchvision>=0.22.0,<0.27.0",
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"opencv-python-headless>=4.9.0,<4.14.0",
"Pillow>=10.0.0,<13.0.0",
@@ -95,11 +95,22 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.7.0,<5.0.0",
"datasets>=4.0.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
# NOTE: torchcodec wheel availability matrix (PyPI):
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
"jsonlines>=4.0.0,<5.0.0",
]
training = [
@@ -195,7 +206,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -293,6 +304,20 @@ lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
# uv pip install --force-reinstall torch torchvision \
# --index-url https://download.pytorch.org/whl/cu130
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
+1
@@ -99,6 +99,7 @@ def save_checkpoint(
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
-4
@@ -24,7 +24,6 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -44,10 +43,7 @@ __all__ = [
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
]
+6
@@ -117,3 +117,9 @@ class PeftConfig:
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
# fine-tuning.
r: int = 16
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
# In general, a higher alpha means stronger adaptation signal.
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
# Common values are r (alpha == rank) or 2*r.
lora_alpha: int | None = None
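To make the scaling formula concrete, here is a quick numeric sketch (the alpha=8 fallback is PEFT's documented default per the comment above; the rank values are illustrative):
```python
# Sketch: LoRA update strength as a function of lora_alpha (scaling = lora_alpha / r).
def lora_scaling(r: int, lora_alpha: int | None) -> float:
    alpha = 8 if lora_alpha is None else lora_alpha  # PEFT falls back to alpha=8 when unset
    return alpha / r

print(lora_scaling(128, None))  # 0.0625 -> high-rank adapter is heavily dampened
print(lora_scaling(128, 128))   # 1.0    -> alpha == r keeps updates at full strength
print(lora_scaling(16, 32))     # 2.0    -> alpha == 2*r amplifies the adaptation signal
```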
+5 -2
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
yaml_overrides = parser.get_yaml_overrides("policy")
cli_overrides = parser.get_cli_overrides("policy") or []
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=yaml_overrides + cli_overrides
)
self.policy.pretrained_path = Path(policy_path)
else:
+83 -1
@@ -13,8 +13,10 @@
# limitations under the License.
import importlib
import inspect
import json
import pkgutil
import sys
import tempfile
from argparse import ArgumentError
from collections.abc import Callable, Iterable, Sequence
from functools import wraps
@@ -24,6 +26,7 @@ from types import ModuleType
from typing import Any, TypeVar, cast
import draccus
import yaml # type: ignore[import-untyped]
from lerobot.utils.utils import has_method
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
# Storage for path args extracted from YAML/JSON config files, so that
# get_path_arg() can find them even when they weren't passed via CLI.
_config_path_args: dict[str, str] = {}
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
_config_yaml_overrides: dict[str, list[str]] = {}
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
args = []
for key, value in d.items():
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
continue
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, bool):
value = str(value).lower()
if isinstance(value, dict):
args.extend(_flatten_to_cli_args(value, full_key))
elif value is not None and not isinstance(value, list):
args.append(f"--{full_key}={value}")
return args
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
"""Parses arguments from cli at a given nested attribute level.
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
return parse_arg(f"{field_name}.{PATH_KEY}", args)
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
if result is None:
result = _config_path_args.get(field_name)
return result
def get_yaml_overrides(field_name: str) -> list[str]:
return _config_yaml_overrides.get(field_name, [])
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
@@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
draccus will fail because ``path`` is not a valid field on policy config classes.
This function extracts those path values, stores them in ``_config_path_args`` for
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
"""
config_file = Path(config_path)
suffix = config_file.suffix.lower()
if suffix in (".yaml", ".yml"):
with open(config_file) as f:
config_data = yaml.safe_load(f)
elif suffix == ".json":
with open(config_file) as f:
config_data = json.load(f)
else:
return config_path
if not isinstance(config_data, dict):
return config_path
modified = False
for field in path_fields:
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
else:
del config_data[field]
modified = True
if not modified:
return config_path
# Write cleaned config to a temp file
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
if suffix in (".yaml", ".yml"):
yaml.dump(config_data, tmp, default_flow_style=False)
else:
json.dump(config_data, tmp, indent=2)
return tmp.name
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
"""
HACK: Similar to draccus.wrap but does three additional things:
@@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
if has_method(argtype, "__get_path_fields__"):
path_fields = argtype.__get_path_fields__()
cli_args = filter_path_args(path_fields, cli_args)
# Also extract path fields from the YAML/JSON config file
if config_path_cli:
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
if has_method(argtype, "from_pretrained") and config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
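To see what the flattening produces end to end, here is a self-contained sketch that mirrors `_flatten_to_cli_args` (minus the `path`/choice-key special-casing); the input dict is illustrative:
```python
# Sketch: nested YAML policy fields -> draccus-style CLI override strings.
def flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
    args: list[str] = []
    for key, value in d.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, bool):
            value = str(value).lower()  # draccus expects "true"/"false"
        if isinstance(value, dict):
            args.extend(flatten_to_cli_args(value, full_key))
        elif value is not None and not isinstance(value, list):  # None and list values are skipped
            args.append(f"--{full_key}={value}")
    return args

print(flatten_to_cli_args({"optimizer_lr": 1e-4, "push_to_hub": True}, "policy"))
# ['--policy.optimizer_lr=0.0001', '--policy.push_to_hub=true']
```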
-206
@@ -1,206 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
"""``${name}`` placeholder pattern used by both recipe binding-reference
discovery (here) and rendered-message substitution (in ``language_render``)."""
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
# ``stream`` is typed Optional only so the dataclass can keep its
# field ordering, but recipes must always tag every turn with a
# stream — the renderer's ``_validate_rendered`` would reject
# ``None`` later on. Fail at construction so the bad recipe is
# caught at YAML load time rather than at the first sample.
if self.stream is None:
raise ValueError(
f"MessageTurn(role={self.role!r}) is missing a stream — "
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
)
if self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)
+8 -10
@@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin):
)
self.reward_model.pretrained_path = str(Path(reward_model_path))
elif policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
yaml_overrides = parser.get_yaml_overrides("policy")
cli_overrides = parser.get_cli_overrides("policy") or []
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=yaml_overrides + cli_overrides
)
self.policy.pretrained_path = Path(policy_path)
elif self.resume:
config_path = parser.parse_arg("config_path")
@@ -256,7 +259,9 @@ class TrainPipelineConfig(HubMixin):
) from e
cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
if config_file is not None and config_file.endswith(".json"):
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
@@ -267,10 +272,3 @@ class TrainPipelineConfig(HubMixin):
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args)
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None  # type: ignore[assignment] # because the parent class has made its type non-optional
-14
@@ -37,14 +37,6 @@ from .dataset_tools import (
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -61,15 +53,10 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"add_features",
@@ -79,7 +66,6 @@ __all__ = [
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",
+1 -1
@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] in {"string", "language"}:
if features[key]["dtype"] == "string":
continue
if features[key]["dtype"] in ["image", "video"]:
+27 -33
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Callable
from pathlib import Path
import numpy as np
@@ -34,6 +35,7 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_stats,
@@ -174,6 +176,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -187,6 +190,29 @@ class LeRobotDatasetMetadata:
if self.episodes is None:
self._load_metadata()
def filter_episodes(
self,
predicate: Callable[[dict], bool],
candidates: list[int] | None = None,
) -> list[int]:
"""Filter episodes whose metadata satisfies a given predicate.
Args:
predicate: Predicate over per-episode metadata rows used to select episodes.
candidates: Optional list of episode indices to restrict evaluation to.
Returns:
List of sorted episode indices that satisfy the predicate.
"""
self.ensure_readable()
if candidates is not None:
candidate_set = set(candidates)
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
else:
combined = predicate
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
return sorted(int(idx) for idx in filtered["episode_index"])
def _pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@@ -316,39 +342,6 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def has_language_columns(self) -> bool:
"""Return ``True`` if the dataset declares any language column.
Used to gate language-aware code paths (collate, render step) so
unannotated datasets keep PyTorch's default collate behavior.
"""
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
return any(col in self.features for col in LANGUAGE_COLUMNS)
@property
def tools(self) -> list[dict]:
"""OpenAI-style tool schemas declared by this dataset.
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
can mutate the result safely. Falls back to
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
``say`` schema) when the dataset doesn't declare any — that way
unannotated datasets and chat-template consumers
(``apply_chat_template(messages, tools=meta.tools)``) keep
working out of the box.
Implementations live under :mod:`lerobot.tools` (one file per
tool); see ``docs/source/tools.mdx`` for the authoring guide.
"""
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
declared = self.info.tools
if declared:
return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -664,6 +657,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
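A hedged usage sketch of the new filtering hook (the repo id and the per-episode `length` key are assumptions; `filter_episodes` comes from the diff above and `episodes=` is the existing `LeRobotDataset` argument):
```python
# Sketch: select episodes by metadata, then load only those episodes.
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata

repo_id = "lerobot/example_dataset"  # assumed
meta = LeRobotDatasetMetadata(repo_id)

# Episodes longer than 200 frames, assuming episode rows carry a "length" field.
long_eps = meta.filter_episodes(lambda ep: ep["length"] > 200)

# Restrict evaluation to a candidate subset of episode indices.
first_long = meta.filter_episodes(lambda ep: ep["length"] > 200, candidates=[0, 1, 2, 3])

dataset = LeRobotDataset(repo_id, episodes=long_eps)
```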
+5
@@ -295,4 +295,9 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item
+1 -15
@@ -22,12 +22,6 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -52,13 +46,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -254,8 +242,6 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return ""
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
+11 -6
@@ -31,10 +31,10 @@ from torchvision import transforms
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
from .language import LANGUAGE_COLUMNS
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
@@ -186,6 +186,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -257,13 +265,11 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in LANGUAGE_COLUMNS:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None or isinstance(first_item, dict):
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -298,9 +304,8 @@ def item_to_torch(item: dict) -> dict:
Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
skip_keys = {"task", *LANGUAGE_COLUMNS}
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
-240
@@ -1,240 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {
"subtask",
"plan",
"memory",
"motion",
"interjection",
"vqa",
"trace",
"task_aug",
}
# Project-local styles can be registered at import time by appending to
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
# validation. Empty by default — populate from a downstream module that
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
# the new style's column.
EXTENDED_STYLES: set[str] = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
# is intentionally NOT in this set: motion primitives are described in
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
# view-dependent. The ``camera`` field nevertheless lives on
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
# behave symmetrically across the two columns; persistent rows simply
# always have ``camera=None`` in practice today.
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def is_view_dependent_style(style: str | None) -> bool:
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
return style in VIEW_DEPENDENT_STYLES
def validate_camera_field(style: str | None, camera: str | None) -> None:
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
a non-view-dependent style carries one. Pipeline writers and the validator
should call this on every emitted row.
"""
if is_view_dependent_style(style):
if not camera:
raise ValueError(
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
f"field referencing an 'observation.images.*' feature key."
)
elif camera is not None:
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
# --- Tool registry --------------------------------------------------------
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
# of OpenAI-style function schemas. The runtime / training stack reads them
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
# fallback when the dataset doesn't declare any). Implementations live
# under :mod:`lerobot.tools` (one file per tool); see
# ``docs/source/tools.mdx`` for the authoring guide.
SAY_TOOL_SCHEMA: dict = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
"""Canonical schema for the ``say`` tool emitted by the steerable
annotation pipeline (PR 2 Module 2). Single source of truth PR 2's
writer, PR 3's runtime tool registry, and the dataset visualizer all
import this constant rather than duplicating the dict."""
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
chat-template consumers (``apply_chat_template(messages, tools=...)``)
keep working out of the box."""
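# Hedged sketch of the chat-template hand-off mentioned above. ``tokenizer`` is
# an assumed Hugging Face tokenizer whose chat template supports tools; loading
# one is outside this module's scope.
def _example_tools_prompt(tokenizer) -> str:
    messages = [{"role": "user", "content": "Say hello."}]
    return tokenizer.apply_chat_template(messages, tools=DEFAULT_TOOLS, tokenize=False)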
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")
-543
View File
@@ -1,543 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
from .language import LANGUAGE_PERSISTENT, column_for_style
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
``camera`` selector. Only valid for persistent styles.
"""
_validate_persistent_resolver("active_at", style)
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) <= t
]
if not matches:
return None
latest_ts = max(_timestamp(row) for row in matches)
return _select_one(
[row for row in matches if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
camera=camera,
)
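# Worked sketch (hypothetical rows): the active row is the latest one whose own
# timestamp is <= t for the selector.
def _example_active_at() -> None:
    rows = [
        {"role": "assistant", "style": "subtask", "timestamp": 1.0, "content": "pick"},
        {"role": "assistant", "style": "subtask", "timestamp": 5.0, "content": "place"},
    ]
    assert active_at(3.0, persistent=rows, style="subtask")["content"] == "pick"
    assert active_at(6.0, persistent=rows, style="subtask")["content"] == "place"
    assert active_at(0.5, persistent=rows, style="subtask") is None  # nothing active yet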
EMITTED_AT_TOLERANCE_S = 0.1
"""Half-window for matching persistent rows to a frame timestamp in
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
is also a float64 from parquet, so in the ideal hot path an exact match
would suffice; but any caller that derives ``t`` arithmetically (e.g.
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
common arithmetic drift without admitting frames that are visibly far
apart at typical control rates (30-100 Hz)."""
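# Illustrative drift (standard float64 behaviour): stepwise accumulation misses
# the stored literal by an ulp-scale error, which the window absorbs.
def _example_timestamp_drift() -> None:
    t = 0.1 + 0.1 + 0.1  # e.g. frame 3 at 10 fps, accumulated step by step
    assert t != 0.3  # bit-equality breaks ...
    assert abs(t - 0.3) <= EMITTED_AT_TOLERANCE_S  # ... but the tolerance holds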
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
we use a tolerance instead of bit-equality). For event styles, the
``events`` list is assumed to come from the dataset row at frame ``t``
(event rows carry no timestamp of their own), so all matching event rows
are considered emitted at ``t``. ``camera`` filters by the row's
``camera`` field; this is required to disambiguate when multiple
view-dependent rows share ``(t, role)`` across cameras.
"""
if column_for_style(style) == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
]
else:
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions before the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions after the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
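# The blend draw in isolation (standalone sketch duplicating the logic above):
# blake2b of the sample index maps to a reproducible point in [0, total_weight),
# immune to Python's per-process hash randomization.
def _example_blend_draw(sample_idx: int, total_weight: float) -> float:
    digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") / 2**64 * total_weight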
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(
task: str | None,
dataset_ctx: Any | None,
*,
persistent: Sequence[LanguageRow] = (),
sample_idx: int = 0,
) -> str | None:
"""Return the task string for ``sample_idx``.
Resolution order:
1. Explicit ``task`` override (caller-supplied) wins.
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
deterministically pick one by ``sample_idx`` so each frame of an
episode rotates through the available rephrasings across an epoch.
This realizes Xiao 2022 / CAST-style task-prompt diversity without
changing ``meta/tasks.parquet`` and without forcing recipes to opt
in: ``${task}`` automatically picks a rephrasing when one exists,
and falls back to the canonical task otherwise. Recipes that want
the literal canonical task can override the binding.
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
backed by ``meta/tasks.parquet``).
"""
if task is not None:
return task
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
return str(chosen)
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
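# Rotation sketch (hypothetical rows): with task_aug rephrasings present, the
# pick is deterministic per sample index; without them, the canonical task wins.
def _example_task_rotation() -> None:
    aug = [
        {"role": "user", "style": "task_aug", "content": "pick up the red block"},
        {"role": "user", "style": "task_aug", "content": "grab the crimson cube"},
    ]
    picks = {_resolve_task(None, None, persistent=aug, sample_idx=i) for i in range(16)}
    assert picks <= {"pick up the red block", "grab the crimson cube"}
    assert _resolve_task(None, {"task": "stack the blocks"}) == "stack the blocks"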
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
if name == "emitted_at":
return emitted_at(t, persistent=persistent, events=events, **kwargs)
if name == "active_at":
return active_at(t, persistent=persistent, **kwargs)
if name == "nth_prev":
return nth_prev(t, persistent=persistent, **kwargs)
if name == "nth_next":
return nth_next(t, persistent=persistent, **kwargs)
raise ValueError(f"Unknown language resolver: {name!r}")
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
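# Parsing sketch: a binding expression splits into a resolver name plus typed
# kwargs; quotes are optional for string values and ``offset`` is coerced to int.
def _example_resolver_parsing() -> None:
    match = _RESOLVER_RE.match("nth_prev(style='subtask', role=assistant, offset=2)")
    assert match is not None and match.group("name") == "nth_prev"
    assert _parse_resolver_args(match.group("args")) == {
        "style": "subtask",
        "role": "assistant",
        "offset": 2,
    }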
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return PLACEHOLDER_RE.sub(replace, template)
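# Substitution sketch (assuming the imported PLACEHOLDER_RE matches ``${name}``):
# string bindings substitute as-is, row bindings contribute their content, and
# None renders as the empty string.
def _example_substitute() -> None:
    bindings: dict[str, LanguageRow | str | None] = {
        "task": "stack the blocks",
        "subtask": {"role": "assistant", "content": "grasp the top block"},
        "memory": None,
    }
    out = _substitute("Task: ${task} | Now: ${subtask} | Mem: ${memory}", bindings)
    assert out == "Task: stack the blocks | Now: grasp the top block | Mem: "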
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
# ``stream`` is enforced non-None at MessageTurn construction time
# (see ``MessageTurn.__post_init__``), so a missing stream here would
# mean the dataclass invariant was bypassed; no need to re-check.
def _nth_relative(
name: str,
t: float,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(name, style)
if abs(offset) < 1:
raise ValueError(f"{name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
key=_row_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{name} requires a persistent style.")
if column_for_style(style) != LANGUAGE_PERSISTENT:
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
and (camera is None or row.get("camera") == camera)
]
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the resolver is ambiguous.
Multiple matches always raise even when the caller already passed
some selectors because remaining ambiguity means the data has
several rows that look identical to the resolver and the caller
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
rows shared across cameras).
"""
if not rows:
return None
if len(rows) > 1:
raise ValueError(
f"Ambiguous resolver for style={style!r} role={role!r} "
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
f"Add a selector that distinguishes them."
)
return rows[0]
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Stable sort key for both persistent and event rows.
Event rows lack ``timestamp`` (it is implicit in the frame), so it defaults
to ``0.0``: within a single frame all event rows share the same sort
bucket and are tie-broken by ``(style, role)``.
"""
timestamp = row.get("timestamp")
ts = (
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
)
return (ts, row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
value = row["timestamp"]
return float(value.item() if hasattr(value, "item") else value)
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
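# Tool-call shape sketch: tool_calls hold OpenAI-style dicts, so the name lives
# under ["function"]["name"]; raw JSON strings are skipped rather than parsed.
def _example_tool_name_match() -> None:
    row = {"tool_calls": [{"function": {"name": "say", "arguments": {"text": "hi"}}}]}
    assert _row_has_tool_name(row, "say")
    assert not _row_has_tool_name(row, "wave")
    assert not _row_has_tool_name({"tool_calls": ['{"function": ...}']}, "say")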
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized
+23 -1
View File
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
episode_filter: Callable[[dict], bool] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
metadata rows used to select episodes. Evaluated against each episode's ``meta/`` row
without the ``stats`` keys (e.g. ``task_index``, ``episode_index``, ``length``,
``from_timestamp``, ``to_timestamp``). Intersected with ``episodes`` when both are
set. Example: ``lambda ep: ep["length"] >= 100``. Defaults to None.
image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be
@@ -199,7 +205,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec()
@@ -218,6 +223,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root = self.meta.root
self.revision = self.meta.revision
if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
):
logger.warning(
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
)
if episode_filter is not None:
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
if not resolved:
raise ValueError(
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
)
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
episodes = resolved
self.episodes = episodes
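# Illustrative usage (sketch; the repo id and thresholds are placeholders):
#     ds = LeRobotDataset(
#         "org/dataset",
#         episode_filter=lambda ep: ep["length"] >= 100 and ep["task_index"] == 0,
#     )
# When ``episodes=[...]`` is also given, the filter applies to that subset only.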
# Create reader (hf_dataset loaded below)
self.reader = DatasetReader(
meta=self.meta,
+1 -7
View File
@@ -88,6 +88,7 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -129,9 +130,6 @@ class DatasetInfo:
# Optional metadata
robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict)
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
tools: list[dict] | None = None
def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation
@@ -153,15 +151,11 @@ class DatasetInfo:
"""Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them.
Drops ``tools`` when unset so existing datasets keep a clean
``info.json``.
"""
d = dataclasses.asdict(self)
for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"])
if d.get("tools") is None:
d.pop("tools", None)
return d
@classmethod
+63 -56
View File
@@ -33,7 +33,6 @@ import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
@@ -132,7 +131,9 @@ def decode_video_frames(
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
accepted for one release as an alias for "pyav" and will be removed in a future version.
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
@@ -145,85 +146,87 @@ def decode_video_frames(
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
)
elif backend == "pyav":
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
def decode_video_frames_pyav(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video
"""Loads frames associated to the requested timestamps of a video using PyAV.
The backend can be either "pyav" (default) or "video_reader".
"video_reader" requires installing torchvision from source, see:
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
(note that you need to compile against ffmpeg<4.3)
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
x86_64 and linux armv7l; see the torchcodec block in pyproject.toml for the full matrix).
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
accurate seek.
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
For more info on video decoding, see `benchmark/video/README.md`
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
forward until we have covered the requested timestamp range. The number of key frames in a
video can be adjusted at encoding time to trade off decoding speed against file size.
See torchvision doc for more info on these two backends:
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
Args:
video_path: Path to the video file.
timestamps: List of timestamps (in seconds) to extract frames for.
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
[0, 1] range.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
Returns:
torch.Tensor of shape (len(timestamps), C, H, W).
"""
video_path = str(video_path)
# set backend
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time
reader = torchvision.io.VideoReader(video_path, "video")
video_path = str(video_path)
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = min(timestamps)
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)
loaded_frames: list[torch.Tensor] = []
loaded_ts: list[float] = []
# load all frames until last requested frame
loaded_frames = []
loaded_ts = []
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
with av.open(video_path) as container:
stream = container.streams.video[0]
container.seek(int(first_ts * av.time_base), backward=True)
if backend == "pyav":
reader.container.close()
for frame in container.decode(stream):
if frame.pts is None:
continue
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
# Convert to CHW uint8 to match torchcodec's output layout.
arr = frame.to_ndarray(format="rgb24") # H, W, 3
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
reader = None
if not loaded_frames:
raise FrameTimestampError(
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
)
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
loaded_ts_t = torch.tensor(loaded_ts)
# compute distances between each query timestamp and timestamps of all loaded frames
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
@@ -234,14 +237,14 @@ def decode_video_frames_torchvision(
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nloaded timestamps: {loaded_ts_t}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
f"\nbackend: pyav"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
closest_ts = loaded_ts_t[argmin_]
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
@@ -282,7 +285,11 @@ class VideoDecoderCache:
with self._lock:
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
decoder = VideoDecoder(file_handle, seek_mode="approximate")
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
return self._cache[video_path][0]
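# Usage sketch (the path is a placeholder): callers normally omit ``backend`` and
# let get_safe_default_codec() pick torchcodec where a wheel exists:
#     frames = decode_video_frames("episode_000000.mp4", [0.0, 0.5], tolerance_s=1e-2)
# Forcing the fallback explicitly:
#     frames = decode_video_frames("episode_000000.mp4", [0.0, 0.5], 1e-2, backend="pyav")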
+5 -5
View File
@@ -18,13 +18,13 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .sac.configuration_sac import SACConfig as SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
@@ -32,21 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
__all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
"PI0Config",
"PI0FastConfig",
"PI05Config",
"SACConfig",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+11 -11
View File
@@ -47,12 +47,12 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
@@ -88,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -127,10 +127,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from .sac.modeling_sac import SACPolicy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
return SACPolicy
return GaussianActorPolicy
elif name == "smolvla":
from .smolvla.modeling_smolvla import SmolVLAPolicy
@@ -167,7 +167,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -191,8 +191,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "groot":
@@ -365,10 +365,10 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from .sac.processor_sac import make_sac_pre_post_processors
elif isinstance(policy_cfg, GaussianActorConfig):
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
processors = make_sac_pre_post_processors(
processors = make_gaussian_actor_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACConfig
from .modeling_sac import SACPolicy
from .processor_sac import make_sac_pre_post_processors
from .configuration_gaussian_actor import GaussianActorConfig
from .modeling_gaussian_actor import GaussianActorPolicy
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
@@ -75,18 +75,19 @@ class PolicyConfig:
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@PreTrainedConfig.register_subclass("gaussian_actor")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
class GaussianActorConfig(PreTrainedConfig):
"""Gaussian actor configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configures the policy-side (actor + observation encoder) of a Gaussian
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
By default the actor output is a tanh-squashed diagonal Gaussian
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
CLI: ``--policy.type=gaussian_actor``.
"""
# Mapping of feature types to normalization modes
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu"
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameter
# Encoder architecture
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
# Number of steps for online training
online_steps: int = 1000000
# Capacity of the online replay buffer
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (to enable it, set utd_ratio to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
# Network architecture
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters (Gaussian head)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
# Any validation specific to GaussianActor configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
# this preset.
default_lr = 3e-4
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
"actor": {"lr": default_lr},
"critic": {"lr": default_lr},
"temperature": {"lr": default_lr},
},
)
@@ -15,16 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import asdict
from typing import Literal
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_sac import SACConfig, is_image_feature
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
class GaussianActorPolicy(
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
config_class = GaussianActorConfig
name = "gaussian_actor"
def __init__(
self,
config: SACConfig | None = None,
config: GaussianActorConfig | None = None,
):
super().__init__(config)
config.validate_features()
@@ -54,9 +49,8 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_temperature()
self._init_discrete_critic()
def get_optim_params(self) -> dict:
optim_params = {
@@ -65,11 +59,7 @@ class SACPolicy(
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
return optim_params
def reset(self):
@@ -79,7 +69,9 @@ class SACPolicy(
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
raise NotImplementedError(
"GaussianActorPolicy does not support action chunking. It returns single actions!"
)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -92,360 +84,43 @@ class SACPolicy(
actions, _, _ = self.actor(batch, observations_features)
if self.config.num_discrete_actions is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
if self.discrete_critic is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
else:
discrete_action = torch.ones(
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
"""Actor forward pass: sample actions and return log-probabilities.
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
batch: A flat observation dict, or a training dict containing
``"state"`` (observations) and optionally ``"observation_feature"``
(pre-computed encoder features).
Returns:
Tensor of Q-values from all critics
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute each critic's mean loss over the batch, then sum across critics for the final loss
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
observations = batch.get("state", batch)
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_critic = GaussianActorObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
"""Initialize policy actor network."""
# NOTE: The actor select only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
@@ -455,21 +130,25 @@ class SACPolicy(
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
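With this default, the target entropy is minus half the total action dimension (np.prod on a scalar is a no-op here, so this reduces to -dim / 2). A hypothetical example: a 6-DoF continuous arm with one discrete gripper head.

continuous_action_dim, has_discrete = 6, True
dim = continuous_action_dim + (1 if has_discrete else 0)  # 7
target_entropy = -dim / 2                                 # -3.5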
def _init_discrete_critic(self) -> None:
"""Initialize discrete critic network."""
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
# TODO(Khalil): Compile the discrete critic
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
class SACObservationEncoder(nn.Module):
class GaussianActorObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig) -> None:
def __init__(self, config: GaussianActorConfig) -> None:
super().__init__()
self.config = config
self._init_image_layers()
@@ -677,84 +356,6 @@ class MLP(nn.Module):
return self.net(x)
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: float | None = None,
init_final: float | None = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
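The returned (num_critics, batch) layout is what both losses rely on: the TD loss takes the mean over dim=1 and sums over critics, while the actor loss takes the min over dim=0. A toy illustration with hypothetical sizes:

import torch

q_values = torch.randn(2, 256)           # (num_critics, batch)
per_critic_loss = q_values.mean(dim=1)   # (num_critics,), as in the TD loss
min_q = q_values.min(dim=0)[0]           # (batch,), pessimistic Q for the actor loss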
class DiscreteCritic(nn.Module):
def __init__(
self,
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
class Policy(nn.Module):
def __init__(
self,
encoder: SACObservationEncoder,
encoder: GaussianActorObservationEncoder,
network: nn.Module,
action_dim: int,
std_min: float = -5,
@@ -811,7 +412,7 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder: SACObservationEncoder = encoder
self.encoder: GaussianActorObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.std_min = std_min
@@ -885,7 +486,7 @@ class Policy(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
image_key = next(key for key in config.input_features if is_image_feature(key))
self.image_enc_layers = nn.Sequential(
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
def _load_pretrained_vision_encoder(self, config: SACConfig):
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
@@ -32,18 +32,18 @@ from lerobot.processor import (
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from .configuration_sac import SACConfig
from .configuration_gaussian_actor import GaussianActorConfig
def make_sac_pre_post_processors(
config: SACConfig,
def make_gaussian_actor_pre_post_processors(
config: GaussianActorConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the SAC policy.
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the SAC policy.
config: The configuration object for the tanh-Gaussian policy.
dataset_stats: A dictionary of statistics for normalization.
Returns:
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
-7
@@ -95,13 +95,6 @@ from .relative_action_processor import (
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
# `lerobot.datasets.language`, which requires the `[dataset]` extra
# (`datasets`, `pyarrow`). Importing it from the processor package would
# break every base-install consumer of `lerobot.processor`. Users that
# need it import directly:
# from lerobot.processor.render_messages_processor import RenderMessagesStep
__all__ = [
"ActionProcessorStep",
"AddTeleopActionAsComplimentaryDataStep",
-18
@@ -174,24 +174,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(
+16 -20
@@ -153,30 +153,26 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
return x
_COMPLEMENTARY_KEYS = (
"task",
"index",
"task_index",
"episode_index",
"timestamp",
"language_persistent",
"language_events",
"messages",
"message_streams",
"target_message_indices",
)
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""Extract complementary data from a batch dictionary.
"""
Extract complementary data from a batch dictionary.
Includes padding flags (any key containing ``_is_pad``) plus the fixed
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS``,
each included only when present in ``batch``.
This includes padding flags, task description, and indices.
Args:
batch: The batch dictionary.
Returns:
A dictionary with the extracted complementary data.
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
return {**pad_keys, **extras}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
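A quick before/after of the new implementation on a hypothetical batch:

batch = {
    "observation.state": 0,  # ignored here (not complementary data)
    "action_is_pad": 1,      # any key containing "_is_pad" is kept
    "task": "pick cube",     # listed in _COMPLEMENTARY_KEYS
    "timestamp": 3.2,        # listed in _COMPLEMENTARY_KEYS
}
_extract_complementary_data(batch)
# -> {"action_is_pad": 1, "task": "pick cube", "timestamp": 3.2}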
def create_transition(
+31 -11
@@ -4,7 +4,6 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
@@ -321,6 +320,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
3. Copying `discrete_penalty` from `info` to `complementary_data`.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -330,6 +330,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if DISCRETE_PENALTY_KEY in info:
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
@@ -348,18 +351,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
Applies a small per-transition cost on the discrete gripper action.
This step penalizes actions that attempt to close an already closed gripper or
open an already open one, based on position thresholds.
Fires only when the commanded action would actually transition the gripper
from one extreme to the other (close-while-open or open-while-closed).
This discourages gripper oscillation while leaving "stay" and saturating-further
commands unpenalized.
Attributes:
penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization.
open_threshold: Normalized state below which the gripper is considered "open".
closed_threshold: Normalized state above which the gripper is considered "closed".
"""
penalty: float = -0.01
penalty: float = -0.02
max_gripper_pos: float = 30.0
open_threshold: float = 0.1
closed_threshold: float = 0.9
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
@@ -379,11 +388,15 @@ class GripperPenaltyProcessorStep(ProcessorStep):
if raw_joint_positions is None:
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
current_gripper_pos = raw_joint_positions.get(f"{GRIPPER_KEY}.pos", None)
if current_gripper_pos is None:
return new_transition
# Gripper action is a PolicyAction at this stage
# During reset, the transition may not carry any action yet.
if action is None:
return new_transition
# Gripper action is expected as the last action dimension.
gripper_action = action[-1].item()
gripper_action_normalized = gripper_action / self.max_gripper_pos
@@ -391,9 +404,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
)
# - currently open AND target is closed -> close transition
# - currently closed AND target is open -> open transition
is_open = gripper_state_normalized < self.open_threshold
is_closed = gripper_state_normalized > self.closed_threshold
cmd_close = gripper_action_normalized > self.closed_threshold
cmd_open = gripper_action_normalized < self.open_threshold
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
gripper_penalty = self.penalty * int(gripper_penalty_bool)
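Concretely, with the defaults (open_threshold=0.1, closed_threshold=0.9) only the two extreme-to-extreme transitions fire. A standalone sketch of the decision logic (hypothetical helper, mirroring the booleans above):

def fires(state: float, action: float, open_t: float = 0.1, closed_t: float = 0.9) -> bool:
    """state/action are normalized gripper positions in [0, 1]."""
    is_open, is_closed = state < open_t, state > closed_t
    cmd_close, cmd_open = action > closed_t, action < open_t
    return (is_open and cmd_close) or (is_closed and cmd_open)

assert fires(0.05, 0.95)      # close-while-open -> penalized
assert fires(0.95, 0.05)      # open-while-closed -> penalized
assert not fires(0.95, 0.95)  # saturating further -> free
assert not fires(0.50, 0.95)  # mid-range state -> free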
@@ -409,11 +426,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the penalty value and max gripper position.
A dictionary containing the penalty value, max gripper position,
and the open/closed thresholds.
"""
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
"open_threshold": self.open_threshold,
"closed_threshold": self.closed_threshold,
}
def reset(self) -> None:
@@ -134,6 +134,24 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
(already ``(C, 1, 1)``). Needed by RL training, which can start without
a dataset and supplies stats manually via JSON config.
"""
for key, feature in self.features.items():
if feature.type != FeatureType.VISUAL:
continue
if key not in self._tensor_stats:
continue
for stat_name, stat_tensor in self._tensor_stats[key].items():
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
continue
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
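Why the reshape matters: a flat (C,) stat cannot broadcast against (B, C, H, W) images, while (C, 1, 1) can. A toy check:

import torch

img = torch.randn(8, 3, 64, 64)         # (B, C, H, W)
mean_flat = torch.randn(3)              # (C,), as supplied via JSON config
mean_img = mean_flat.reshape(-1, 1, 1)  # (C, 1, 1)
normalized = img - mean_img             # broadcasts over B, H, W; (C,) would raise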
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -152,6 +170,7 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -201,6 +220,7 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -211,6 +231,7 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
@@ -1,94 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=_scalar(timestamp),
sample_idx=int(_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return _scalar(value[0])
return value
@@ -30,7 +30,7 @@ class RewardClassifierConfig(RewardModelConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
model_name: str = "lerobot/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
@@ -105,6 +105,7 @@ class Classifier(PreTrainedRewardModel):
def __init__(
self,
config: RewardClassifierConfig,
**kwargs,
):
from transformers import AutoModel
+26 -16
@@ -12,23 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reinforcement learning modules.
"""Reinforcement learning modules.
Requires: ``pip install 'lerobot[hilserl]'``
Available modules (import directly)::
from lerobot.rl.actor import ...
from lerobot.rl.learner import ...
from lerobot.rl.learner_service import ...
from lerobot.rl.buffer import ...
from lerobot.rl.eval_policy import ...
from lerobot.rl.gym_manipulator import ...
Distributed actor / learner entry points (``actor``, ``learner``,
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
buffer, data sources and trainer are gRPC-free and usable standalone.
"""
from lerobot.utils.import_utils import require_package
from .algorithms.base import RLAlgorithm as RLAlgorithm
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
from .algorithms.factory import (
make_algorithm as make_algorithm,
make_algorithm_config as make_algorithm_config,
)
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
from .buffer import ReplayBuffer as ReplayBuffer
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
from .trainer import RLTrainer as RLTrainer
require_package("grpcio", extra="hilserl", import_name="grpc")
__all__: list[str] = []
__all__ = [
"RLAlgorithm",
"RLAlgorithmConfig",
"TrainingStats",
"make_algorithm",
"make_algorithm_config",
"SACAlgorithmConfig",
"RLTrainer",
"ReplayBuffer",
"DataMixer",
"OnlineOfflineMixer",
]
+113 -78
@@ -49,39 +49,53 @@ https://github.com/michel-aractingi/lerobot-hilserl-guide
import logging
import os
import time
from collections.abc import Generator
from functools import lru_cache
from queue import Empty
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import _grpc_available, require_package
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.utils.utils import (
@@ -89,19 +103,24 @@ from lerobot.utils.utils import (
init_logging,
)
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .gym_manipulator import (
create_transition,
make_processors,
make_robot_env,
reset_and_build_transition,
step_env_and_process_transition,
)
from .queue import get_last_item_from_queue
from .train_rl import TrainRLServerPipelineConfig
# Main entry point
@parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate()
display_pid = False
if not use_threads(cfg):
@@ -212,7 +231,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
def act_with_policy(
cfg: TrainRLServerPipelineConfig,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
@@ -252,22 +271,24 @@ def act_with_policy(
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### To avoid sending a policy object through the port, we create a policy instance
### on both sides; the learner sends updated parameters every n steps to refresh the actor's copy
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
policy = policy.to(device).eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Build the algorithm
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
transition = reset_and_build_transition(online_env, env_processor, action_processor)
# NOTE: For now we only handle the single-environment case
sum_reward_episode = 0
@@ -291,8 +312,17 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
# Unnormalize only the continuous part.
if cfg.policy.num_discrete_actions is not None:
continuous_action = postprocessor.process_action(action[..., :-1])
discrete_action = action[..., -1:].to(
device=continuous_action.device, dtype=continuous_action.dtype
)
action = torch.cat([continuous_action, discrete_action], dim=-1)
else:
action = postprocessor.process_action(action)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -326,7 +356,8 @@ def act_with_policy(
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
if is_intervention:
episode_intervention = True
episode_intervention_steps += 1
@@ -334,6 +365,7 @@ def act_with_policy(
"discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
),
TeleopEvents.IS_INTERVENTION.value: is_intervention,
}
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
@@ -354,7 +386,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -390,14 +422,7 @@ def act_with_policy(
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
transition = reset_and_build_transition(online_env, env_processor, action_processor)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
@@ -408,10 +433,10 @@ def act_with_policy(
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event
attempts: int = 30,
):
) -> bool:
"""Establish a connection with the learner.
Args:
@@ -441,12 +466,14 @@ def establish_learner_connection(
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""
Returns a client for the learner service.
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
"""Return a client for the learner service.
gRPC uses HTTP/2, which multiplexes requests over a single binary connection,
so we create one client and reuse it.
Returns:
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
"""
channel = grpc.insecure_channel(
@@ -461,16 +488,18 @@ def learner_service_client(
def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Receive parameters from the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
@@ -513,12 +542,11 @@ def receive_policy(
def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send transitions to the learner.
This function continuously retrieves messages from the queue and processes them:
@@ -526,6 +554,13 @@ def send_transitions(
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
transitions_queue (Queue): The queue to receive the transitions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -563,18 +598,24 @@ def send_transitions(
def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send interactions to the learner.
This function continuously retrieves messages from the queue and processes them:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
interactions_queue (Queue): The queue to receive the interactions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -613,7 +654,11 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
def transitions_stream(
shutdown_event: Any, # Event
transitions_queue: Queue,
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
@@ -629,10 +674,10 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout:
def interactions_stream(
shutdown_event: Event,
shutdown_event: Any, # Event
interactions_queue: Queue,
timeout: float, # type: ignore
) -> services_pb2.Empty:
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
@@ -652,7 +697,8 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
@@ -667,18 +713,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
algorithm.load_weights(state_dicts, device=device)
# Utilities functions
+20
@@ -0,0 +1,20 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sac import SACAlgorithm, SACAlgorithmConfig
__all__ = [
"SACAlgorithm",
"SACAlgorithmConfig",
]
+207
@@ -0,0 +1,207 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import os
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
from torch.optim import Optimizer
from lerobot.types import BatchType
from lerobot.utils.hub import HubMixin
from .configs import RLAlgorithmConfig, TrainingStats
if TYPE_CHECKING:
from torch import nn
from ..data_sources.data_mixer import DataMixer
T = TypeVar("T", bound="RLAlgorithm")
class RLAlgorithm(HubMixin, abc.ABC):
"""Base for all RL algorithms."""
config_class: type[RLAlgorithmConfig]
name: str
config: RLAlgorithmConfig
@abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One complete training step.
The algorithm calls ``next(batch_iterator)`` as many times as it
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
The iterator is owned by the trainer; the algorithm just consumes
from it.
"""
raise NotImplementedError
def configure_data_iterator(
self,
data_mixer: DataMixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
) -> Iterator[BatchType]:
"""Create the data iterator this algorithm needs.
The default implementation uses the standard ``data_mixer.get_iterator()``.
Algorithms that need specialised sampling should override this method.
"""
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
@abc.abstractmethod
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""Build and return the optimizers used during training.
Called once on the learner side after construction.
"""
raise NotImplementedError
def get_optimizers(self) -> dict[str, Optimizer]:
"""Return optimizers for checkpointing / external scheduling."""
return {}
@property
def optimization_step(self) -> int:
"""Current learner optimization step.
Part of the stable contract for checkpoint/resume. Algorithms can
either use this default storage or override for custom behavior.
"""
return getattr(self, "_optimization_step", 0)
@optimization_step.setter
def optimization_step(self, value: int) -> None:
self._optimization_step = int(value)
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors."""
return {}
@abc.abstractmethod
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner."""
raise NotImplementedError
@abc.abstractmethod
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Must return a flat tensor mapping for everything the algorithm owns
that is not part of the policy (e.g. critic ensembles, target networks,
temperature parameters). Algorithms with no training-only tensors
should explicitly return an empty dict.
"""
raise NotImplementedError
@abc.abstractmethod
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
Implementations MUST keep the identity of any ``nn.Parameter`` that an
optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
rather than rebinding the attribute.
"""
raise NotImplementedError
def _save_pretrained(self, save_directory: Path) -> None:
"""Persist the algorithm's tensors and config to ``save_directory``.
Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`).
"""
tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()}
save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE))
self.config._save_pretrained(save_directory)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
policy: nn.Module,
config: RLAlgorithmConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
device: str | torch.device = "cpu",
**algo_kwargs: Any,
) -> T:
"""Build an algorithm and load its weights from ``pretrained_name_or_path``."""
if config is None:
config = cls.config_class.from_pretrained(
pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if hasattr(config, "policy_config"):
config.policy_config = policy.config
instance = cls(policy=policy, config=config, **algo_kwargs)
model_id = str(pretrained_name_or_path)
if os.path.isdir(model_id):
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
tensors = load_safetensors(model_file)
instance.load_state_dict(tensors, device=device)
return instance
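Hypothetical usage of the loader above (the repo id is a placeholder, and policy is assumed to be a matching, already-constructed policy module; the safetensors file only covers the algorithm-owned tensors):

from lerobot.rl.algorithms.sac import SACAlgorithm

algo = SACAlgorithm.from_pretrained(
    "user/sac-checkpoint",  # hypothetical Hub repo id or local directory
    policy=policy,
    device="cuda",
)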
+138
@@ -0,0 +1,138 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RLAlgorithmConfig")
logger = logging.getLogger(__name__)
@dataclass
class TrainingStats:
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
losses: dict[str, float] = field(default_factory=dict)
grad_norms: dict[str, float] = field(default_factory=dict)
extra: dict[str, float] = field(default_factory=dict)
def to_log_dict(self) -> dict[str, float]:
"""Flatten all stats into a single dict for logging."""
d: dict[str, float] = {}
for name, val in self.losses.items():
d[name] = val
for name, val in self.grad_norms.items():
d[f"{name}_grad_norm"] = val
for name, val in self.extra.items():
d[name] = val
return d
@dataclass
class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Registry for algorithm configs."""
@property
def type(self) -> str:
"""Registered name of this algorithm config (e.g. ``"sac"``)."""
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@classmethod
@abc.abstractmethod
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
"""Build an algorithm config from a policy config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
def _save_pretrained(self, save_directory: Path) -> None:
"""Serialize this config as ``config.json`` inside ``save_directory``."""
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**algo_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with draccus.config_type("json"):
instance = draccus.parse(RLAlgorithmConfig, config_file, args=[])
if cls is not RLAlgorithmConfig and not isinstance(instance, cls):
raise TypeError(
f"Config at {model_id} has type '{instance.type}' but was loaded via "
f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()."
)
for key, value in algo_kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
return instance
+99
@@ -0,0 +1,99 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from .base import RLAlgorithm
from .configs import RLAlgorithmConfig
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
Args:
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
**kwargs: Keyword arguments forwarded to the config class constructor.
Returns:
An instance of the matching ``RLAlgorithmConfig`` subclass.
Raises:
ValueError: If ``algorithm_type`` is not registered.
"""
try:
cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
except KeyError as err:
raise ValueError(
f"Algorithm type '{algorithm_type}' is not registered. "
f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}"
) from err
return cls(**kwargs)
def get_algorithm_class(name: str) -> type[RLAlgorithm]:
"""
Retrieves an RL algorithm class by its registered name.
This function uses dynamic imports to avoid loading all algorithm classes into
memory at once, improving startup time and reducing dependencies.
Args:
name: The name of the algorithm. Currently only "sac" is supported.
Returns:
The algorithm class corresponding to the given name.
Raises:
ValueError: If the algorithm name is not recognized.
"""
if name == "sac":
from .sac.sac_algorithm import SACAlgorithm
return SACAlgorithm
raise ValueError(
f"Algorithm type '{name}' is not available. "
f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}"
)
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
"""
Instantiate an RL algorithm.
This factory function looks up the :class:`RLAlgorithm` subclass that matches
``cfg.type`` and instantiates it with the provided policy. It also enforces
that ``cfg.policy_config`` has been populated before construction (this is
normally handled by :meth:`TrainRLServerPipelineConfig.validate`).
Args:
cfg: The algorithm configuration. Must have ``policy_config`` set.
policy: The policy module the algorithm will train.
Returns:
An instantiated :class:`RLAlgorithm`.
Raises:
ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not
registered.
"""
if getattr(cfg, "policy_config", None) is None:
raise ValueError(
f"{type(cfg).__name__}.policy_config is None. "
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
"before calling make_algorithm()."
)
cls = get_algorithm_class(cfg.type)
return cls(policy=policy, config=cfg)
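A hedged end-to-end sketch of the factory path (assumes policy already exists; setting policy_config by hand stands in for what TrainRLServerPipelineConfig.validate normally does):

from lerobot.rl.algorithms.factory import make_algorithm, make_algorithm_config

cfg = make_algorithm_config("sac", actor_lr=1e-4)  # kwargs forwarded to SACAlgorithmConfig
cfg.policy_config = policy.config                  # normally populated by validate()
algorithm = make_algorithm(cfg, policy=policy)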
+18
@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACAlgorithmConfig
from .sac_algorithm import SACAlgorithm
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
@@ -0,0 +1,99 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
CriticNetworkConfig,
GaussianActorConfig,
)
from ..configs import RLAlgorithmConfig
@RLAlgorithmConfig.register_subclass("sac")
@dataclass
class SACAlgorithmConfig(RLAlgorithmConfig):
"""Soft Actor-Critic (SAC) algorithm configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum
entropy reinforcement learning framework. It learns a policy and a Q-function
simultaneously using experience collected from the environment.
This configuration class contains the algorithm-side hyperparameters: critic
ensemble, target networks, temperature / entropy tuning, and the Bellman
update loop. The policy-side (actor + observation encoder) lives in
:class:`~lerobot.policies.gaussian_actor.GaussianActorConfig` and is
referenced via :attr:`policy_config`.
"""
# Optimizer learning rates
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Bellman update
# Discount factor for the SAC algorithm
discount: float = 0.99
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Weight (Polyak tau) for the critic target update; see the sketch after this config class
critic_target_update_weight: float = 0.005
# Critic ensemble
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Temperature / entropy
# Initial temperature value
temperature_init: float = 1.0
# Target entropy for automatic temperature tuning. If ``None``, defaults to
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
# there is a discrete action head).
target_entropy: float | None = None
# Update loop
# Update-to-data ratio. Set to >1 to enable extra critic updates per env step.
utd_ratio: int = 1
# Frequency of policy updates
policy_update_freq: int = 1
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Optimizations
# torch.compile is currently disabled by default
use_torch_compile: bool = False
# Policy config
policy_config: PreTrainedConfig | None = None
@classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config with default hyperparameters for a given policy."""
return cls(
policy_config=policy_cfg,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
)
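critic_target_update_weight above is the Polyak coefficient tau for maintaining the target critics built in _init_critics. The excerpt does not show the update itself, so this is the standard rule as a hedged sketch:

import torch

@torch.no_grad()
def soft_update(online: torch.nn.Module, target: torch.nn.Module, tau: float = 0.005) -> None:
    # target <- tau * online + (1 - tau) * target
    for p, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(1.0 - tau).add_(p, alpha=tau)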
@@ -0,0 +1,672 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from collections.abc import Callable, Iterator
from dataclasses import asdict
from typing import Any
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
DISCRETE_DIMENSION_INDEX,
MLP,
DiscreteCritic,
GaussianActorObservationEncoder,
GaussianActorPolicy,
orthogonal_init,
)
from lerobot.policies.utils import get_device_from_parameters
from lerobot.types import BatchType
from lerobot.utils.constants import ACTION
from lerobot.utils.transition import move_state_dict_to_device
from ..base import RLAlgorithm
from ..configs import TrainingStats
from .configuration_sac import SACAlgorithmConfig
class SACAlgorithm(RLAlgorithm):
"""Soft Actor-Critic. Owns critics, targets, temperature, and loss computation."""
config_class = SACAlgorithmConfig
name = "sac"
def __init__(
self,
policy: GaussianActorPolicy,
config: SACAlgorithmConfig,
):
self.config = config
self.policy_config = config.policy_config
self.policy = policy
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
action_dim = self.policy.config.output_features[ACTION].shape[0]
self._init_critics(action_dim)
self._init_temperature(action_dim)
self._device = torch.device(self.policy.config.device)
self._move_to_device()
def _init_critics(self, action_dim: int) -> None:
"""Build critic ensemble, targets."""
encoder = self.policy.encoder_critic
heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
target_heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
# TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled, policy does not converge when enabled.
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.discrete_critic_target = None
if self.policy_config.num_discrete_actions is not None:
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
"""Build target discrete critic (main network is owned by the policy)."""
discrete_critic_target = DiscreteCritic(
encoder=encoder,
input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO(Khalil): Compile the discrete critic
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
return discrete_critic_target
def _init_temperature(self, continuous_action_dim: int) -> None:
"""Set up temperature parameter (log_alpha) and target entropy."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
total_action_dim = continuous_action_dim + (
1 if self.policy_config.num_discrete_actions is not None else 0
)
self.target_entropy = -total_action_dim / 2
def _move_to_device(self) -> None:
self.policy.to(self._device)
self.critic_ensemble.to(self._device)
self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if self.discrete_critic_target is not None:
self.discrete_critic_target.to(self._device)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def _critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def _discrete_critic_forward(
self, observations: dict[str, Tensor], use_target: bool = False, observation_features: Tensor | None = None
) -> Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Run one SAC training step (critic / discrete-critic / actor / temperature).
Pulls ``utd_ratio`` batches from ``batch_iterator``, computes the relevant
losses, backpropagates each, and updates target networks.
Args:
batch_iterator: yields batches each containing
- ``action``: Action tensor
- ``reward``: Reward tensor
- ``state``: Observations tensor dict
- ``next_state``: Next observations tensor dict
- ``done``: Done mask tensor
- ``observation_feature``: Optional pre-computed observation features
- ``next_observation_feature``: Optional pre-computed next observation features
- ``complementary_info`` (optional): per-step extras like discrete penalties
Returns:
TrainingStats with per-component losses and grad norms.
"""
clip = self.config.grad_clip_norm
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch, include_complementary_info=True)
loss_critic = self._compute_loss_critic(fb)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
self.optimizers["critic"].step()
if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
self.optimizers["discrete_critic"].step()
self._update_target_networks()
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch, include_complementary_info=False)
loss_critic = self._compute_loss_critic(fb)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item()
self.optimizers["critic"].step()
stats = TrainingStats(
losses={"loss_critic": loss_critic.item()},
grad_norms={"critic": critic_grad},
)
if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(), max_norm=clip
).item()
self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_dc.item()
stats.grad_norms["discrete_critic"] = dc_grad
if self._optimization_step % self.config.policy_update_freq == 0:
for _ in range(self.config.policy_update_freq):
loss_actor = self._compute_loss_actor(fb)
self.optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad = torch.nn.utils.clip_grad_norm_(
self.policy.actor.parameters(), max_norm=clip
).item()
self.optimizers["actor"].step()
loss_temp = self._compute_loss_temperature(fb)
self.optimizers["temperature"].zero_grad()
loss_temp.backward()
temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item()
self.optimizers["temperature"].step()
stats.losses["loss_actor"] = loss_actor.item()
stats.losses["loss_temperature"] = loss_temp.item()
stats.grad_norms["actor"] = actor_grad
stats.grad_norms["temperature"] = temp_grad
stats.extra["temperature"] = self.temperature
self._update_target_networks()
self._optimization_step += 1
return stats
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
# Extract common components from batch
observations = batch["state"]
actions = batch[ACTION]
observation_features = batch.get("observation_feature")
# Extract critic-specific components
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
next_observation_features = batch.get("next_observation_feature")
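# 1- sample next actions and their log-probabilities from the current policy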
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.policy.actor(
next_observations, next_observation_features
)
# 2- compute q targets
q_targets = self._critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# Subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# Take the minimum Q-value across the (possibly subsampled) critic ensemble
min_q, _ = q_targets.min(dim=0)
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
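# Soft Bellman target (with backup entropy): y = r + gamma * (1 - done) * (min_i Q_i(s', a') - alpha * log pi(a'|s'))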
# 3- compute predicted qs
if self.policy_config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self._critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute the mean TD loss over the batch for each critic, then sum across critics for the final loss
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
observation_features = batch.get("observation_feature")
next_observation_features = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self._discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self._discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
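# Double-DQN target: y = r + gamma * (1 - done) * Q_target(s', argmax_a Q_online(s', a))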
# Get predicted Q-values for current observations
predicted_discrete_qs = self._discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
observation_features = batch.get("observation_feature")
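# SAC actor objective: minimize E[alpha * log pi(a|s) - min_i Q_i(s, a)],
# i.e. push the policy toward high-Q, high-entropy actions.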
actions_pi, log_probs, _ = self.policy.actor(observations, observation_features)
q_preds = self._critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
"""Compute the temperature loss"""
observations = batch["state"]
observation_features = batch.get("observation_feature")
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.policy.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def _update_target_networks(self) -> None:
"""Update target networks with exponential moving average"""
for target_p, p in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
):
target_p.data.copy_(
p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
)
if self.policy_config.num_discrete_actions is not None:
for target_p, p in zip(
self.discrete_critic_target.parameters(),
self.policy.discrete_critic.parameters(),
strict=True,
):
target_p.data.copy_(
p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
)
def _prepare_forward_batch(
self, batch: BatchType, *, include_complementary_info: bool = True
) -> dict[str, Any]:
observations = batch["state"]
next_observations = batch["next_state"]
observation_features, next_observation_features = self.get_observation_features(
observations, next_observations
)
forward_batch: dict[str, Any] = {
ACTION: batch[ACTION],
"reward": batch["reward"],
"state": observations,
"next_state": next_observations,
"done": batch["done"],
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
if include_complementary_info and "complementary_info" in batch:
forward_batch["complementary_info"] = batch["complementary_info"]
return forward_batch
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
A dictionary mapping component names ("actor", "critic", "temperature")
to their respective Adam optimizers.
"""
actor_params = self.policy.get_optim_params()["actor"]
self.optimizers = {
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
}
if self.policy_config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam(
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
)
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
def get_weights(self) -> dict[str, Any]:
"""Send actor + discrete-critic state dicts."""
state_dicts: dict[str, Any] = {
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
}
if self.policy_config.num_discrete_actions is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
self.policy.discrete_critic.state_dict(), device="cpu"
)
return state_dicts
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load actor + discrete-critic weights into the policy."""
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.policy.discrete_critic.load_state_dict(discrete_sd)
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Encoder weights are stripped because they are owned by the policy
(``policy.encoder_critic``) and already saved via ``policy.save_pretrained``.
"""
bundle: dict[str, torch.Tensor] = {}
for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items():
bundle[f"critic_ensemble.{k}"] = v
for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items():
bundle[f"critic_target.{k}"] = v
if self.discrete_critic_target is not None:
for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items():
bundle[f"discrete_critic_target.{k}"] = v
bundle["log_alpha"] = self.log_alpha.detach()
return bundle
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
``log_alpha`` is restored via ``Parameter.data.copy_`` so the
``temperature`` optimizer's reference to the parameter object stays
valid after resume.
"""
critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.")
critic_target_state = _split_prefix(state_dict, "critic_target.")
self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False)
self.critic_target.load_state_dict(critic_target_state, strict=False)
if self.discrete_critic_target is not None:
discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.")
self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False)
if "log_alpha" in state_dict:
self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device))
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
"""
Get observation features from the policy encoder. It act as cache for the observation features.
when the encoder is frozen, the observation features are not updated.
We can save compute by caching the observation features.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Drop ``encoder.*`` keys from a critic-module state dict."""
return {k: v for k, v in state.items() if not k.startswith("encoder.")}
def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
"""Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped."""
return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)}
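# Example (illustrative): for state = {"critic_target.critics.0.weight": w, "log_alpha": a},
# _split_prefix(state, "critic_target.") returns {"critics.0.weight": w}, while
# _strip_encoder_keys would drop any key starting with "encoder.".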
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: float | None = None,
init_final: float | None = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (GaussianActorObservationEncoder): encoder for observations.
ensemble (list[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: GaussianActorObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
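A quick shape check for `CriticHead` (a hedged sketch; it relies only on the signature defined above and assumes `MLP` behaves as a plain feed-forward stack):

import torch

head = CriticHead(input_dim=8, hidden_dims=[32, 32])
x = torch.randn(4, 8)  # 4 samples of concatenated encoded-observation + action
q = head(x)            # one scalar Q-value per sample
assert q.shape == (4, 1)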
+3 -3
@@ -97,8 +97,8 @@ class ReplayBuffer:
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
state_keys (list[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Callable | None): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
@@ -634,7 +634,7 @@ class ReplayBuffer:
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
transitions (list[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:
+2 -2
@@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
crop_params_dict (dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
resize_size (tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
+19
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.types import BatchType
from .data_mixer import DataMixer, OnlineOfflineMixer
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
+97
@@ -0,0 +1,97 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from lerobot.types import BatchType
from ..buffer import ReplayBuffer, concatenate_batch_transitions
class DataMixer(abc.ABC):
"""Abstract interface for all data mixing strategies."""
@abc.abstractmethod
def sample(self, batch_size: int) -> BatchType:
"""Draw one batch of ``batch_size`` transitions."""
raise NotImplementedError
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Infinite iterator that yields batches."""
while True:
yield self.sample(batch_size)
class OnlineOfflineMixer(DataMixer):
"""Mixes transitions from an online and an offline replay buffer."""
def __init__(
self,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer | None = None,
online_ratio: float = 1.0,
):
if not 0.0 <= online_ratio <= 1.0:
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
self.online_buffer = online_buffer
self.offline_buffer = offline_buffer
self.online_ratio = online_ratio
def sample(self, batch_size: int) -> BatchType:
if self.offline_buffer is None:
return self.online_buffer.sample(batch_size)
n_online = max(1, int(batch_size * self.online_ratio))
n_offline = batch_size - n_online
online_batch = self.online_buffer.sample(n_online)
offline_batch = self.offline_buffer.sample(n_offline)
return concatenate_batch_transitions(online_batch, offline_batch)
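# Example split (illustrative): batch_size=256, online_ratio=0.5
# -> n_online = max(1, int(256 * 0.5)) = 128 and n_offline = 256 - 128 = 128.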
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Yield batches by composing buffer async iterators."""
n_online = max(1, int(batch_size * self.online_ratio))
online_iter = self.online_buffer.get_iterator(
batch_size=n_online,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
if self.offline_buffer is None:
yield from online_iter
return
n_offline = batch_size - n_online
offline_iter = self.offline_buffer.get_iterator(
batch_size=n_offline,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
while True:
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))
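A hedged sketch of the `DataMixer` contract: a subclass only needs `sample`, and the base class's `get_iterator` turns it into an infinite batch stream. The toy transitions below are hypothetical stand-ins for a real `ReplayBuffer`:

import random

class ToyMixer(DataMixer):
    def __init__(self, transitions: list):
        self.transitions = transitions

    def sample(self, batch_size: int):
        # A real mixer returns a BatchType dict; a plain list is enough here.
        return random.choices(self.transitions, k=batch_size)

mixer = ToyMixer(transitions=list(range(100)))
iterator = mixer.get_iterator(batch_size=8)
assert len(next(iterator)) == 8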
+1 -1
@@ -17,7 +17,6 @@ import logging
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_policy
from lerobot.robots import ( # noqa: F401
@@ -31,6 +30,7 @@ from lerobot.teleoperators import (
)
from .gym_manipulator import make_robot_env
from .train_rl import TrainRLServerPipelineConfig
logging.basicConfig(level=logging.INFO)
+108 -73
@@ -74,6 +74,7 @@ from lerobot.teleoperators import (
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.import_utils import require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
@@ -312,6 +313,7 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
# Check if this is a GymHIL simulation environment
if cfg.name == "gym_hil":
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
require_package("gym-hil", extra="hilserl", import_name="gym_hil")
import gym_hil # noqa: F401
# Extract gripper settings with defaults
@@ -383,10 +385,21 @@ def make_processors(
GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
# Add time limit processor if reset config exists
if cfg.processor.reset is not None:
env_pipeline_steps.append(
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
)
env_pipeline_steps.extend(
[
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
)
return DataProcessorPipeline(
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
), DataProcessorPipeline(
@@ -551,8 +564,19 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
# Merge env and action-processor info: env wins for str keys, action-processor
# wins for `TeleopEvents` enum keys
action_info = processed_action_transition[TransitionKey.INFO]
new_info = info.copy()
new_info.update(processed_action_transition[TransitionKey.INFO])
for key, value in action_info.items():
if isinstance(key, TeleopEvents):
new_info[key] = value
new_transition = create_transition(
observation=obs,
@@ -568,6 +592,24 @@ def step_env_and_process_transition(
return new_transition
def reset_and_build_transition(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
) -> EnvTransition:
"""Reset env + processors and return the first env-processed transition."""
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
complementary_data: dict[str, Any] = {}
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
return env_processor(data=transition)
def control_loop(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
@@ -593,17 +635,7 @@ def control_loop(
print("- When not intervening, robot will stay still")
print("- Press Ctrl+C to exit")
# Reset environment and processors
obs, info = env.reset()
complementary_data = (
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
)
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(data=transition)
transition = reset_and_build_transition(env, env_processor, action_processor)
# Determine if gripper is used
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
@@ -659,79 +691,82 @@ def control_loop(
episode_step = 0
episode_start_time = time.perf_counter()
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
try:
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
# Use the new step function
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {
observation = {
k: v.squeeze(0).cpu()
for k, v in transition[TransitionKey.OBSERVATION].items()
if isinstance(v, torch.Tensor)
}
# Use teleop_action if available, otherwise use the action from the transition
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
frame = {
**observations,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
if cfg.mode == "record":
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
)
frame = {
**observation,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"discrete_penalty", 0.0
)
frame["complementary_info.discrete_penalty"] = np.array(
[discrete_penalty], dtype=np.float32
)
episode_step += 1
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
episode_step += 1
if dataset is not None:
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
# Reset for new episode
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
if dataset is not None:
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# Reset for new episode
transition = reset_and_build_transition(env, env_processor, action_processor)
# Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
# Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
finally:
if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None:
logging.info("Waiting for image writer to finish...")
dataset.writer.image_writer.stop()
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Finalizing dataset before pushing to hub")
+123 -309
@@ -51,9 +51,21 @@ import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import _grpc_available, require_package
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2_grpc
else:
grpc = None
services_pb2_grpc = None
import grpc
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
@@ -68,14 +80,11 @@ from lerobot.common.train_utils import (
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
bytes_to_python_object,
@@ -84,26 +93,35 @@ from lerobot.transport.utils import (
)
from lerobot.utils.constants import (
ACTION,
ALGORITHM_DIR,
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
init_logging,
)
from .buffer import ReplayBuffer, concatenate_batch_transitions
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg):
import torch.multiprocessing as mp
@@ -179,7 +197,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
def start_learner_threads(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
) -> None:
"""
Start the learner threads for training.
@@ -253,7 +271,7 @@ def start_learner_threads(
def add_actor_information_and_train(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
@@ -266,8 +284,8 @@ def add_actor_information_and_train(
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Delegates training updates to an ``RLAlgorithm``.
- Periodically pushes updated weights to actors.
- Logs training statistics, including loss values and optimization frequency.
NOTE: This function doesn't have a single responsibility; it should be split into multiple functions
@@ -286,17 +304,13 @@ def add_actor_information_and_train(
# of 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@@ -308,7 +322,7 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
@@ -317,15 +331,17 @@ def add_actor_information_and_train(
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
# Push initial policy weights to actors
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
@@ -338,21 +354,37 @@ def add_actor_information_and_train(
device=device,
storage_device=storage_device,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
# DataMixer: online-only or online/offline 50-50 mix
data_mixer = OnlineOfflineMixer(
online_buffer=replay_buffer,
offline_buffer=offline_replay_buffer,
online_ratio=cfg.online_ratio,
)
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=batch_size,
preprocessor=preprocessor,
)
# If we are resuming, we need to load the training state
optimizers = algorithm.get_optimizers()
resume_optimization_step, resume_interaction_step = load_training_state(
cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
)
logging.info("Starting learner thread")
interaction_message = None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
algorithm.optimization_step = optimization_step
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@@ -365,7 +397,6 @@ def add_actor_information_and_train(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
offline_replay_buffer=offline_replay_buffer,
device=device,
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
@@ -382,180 +413,20 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["discrete_critic"].step()
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
critic_output = policy.forward(forward_batch, model="critic")
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["discrete_critic"].step()
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
stats = trainer.training_step()
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
last_time_policy_pushed = time.time()
# Update target networks (main and discrete)
policy.update_target_networks()
training_infos = stats.to_log_dict()
# Log training metrics at specified intervals
optimization_step = algorithm.optimization_step
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
@@ -583,7 +454,6 @@ def add_actor_information_and_train(
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
@@ -597,9 +467,12 @@ def add_actor_information_and_train(
policy=policy,
optimizers=optimizers,
replay_buffer=replay_buffer,
algorithm=algorithm,
offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id,
fps=fps,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
@@ -607,7 +480,7 @@ def start_learner(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
cfg: TrainRLServerPipelineConfig,
):
"""
@@ -681,9 +554,12 @@ def save_training_checkpoint(
policy: nn.Module,
optimizers: dict[str, Optimizer],
replay_buffer: ReplayBuffer,
algorithm: RLAlgorithm | None = None,
offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None,
fps: int = 30,
preprocessor=None,
postprocessor=None,
) -> None:
"""
Save training checkpoint and associated data.
@@ -707,6 +583,8 @@ def save_training_checkpoint(
offline_replay_buffer: Optional offline replay buffer to save
dataset_repo_id: Repository ID for dataset
fps: Frames per second for dataset
preprocessor: Optional preprocessor pipeline to save
postprocessor: Optional postprocessor pipeline to save
"""
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
@@ -715,7 +593,7 @@ def save_training_checkpoint(
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
# Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/).
save_checkpoint(
checkpoint_dir=checkpoint_dir,
step=optimization_step,
@@ -723,13 +601,22 @@ def save_training_checkpoint(
policy=policy,
optimizer=optimizers,
scheduler=None,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
training_state = {"step": optimization_step, "interaction_step": interaction_step}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Algorithm-owned tensors live in their own component subfolder
# so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
if algorithm is not None:
algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
# Enrich training_step.json with the RL-specific interaction_step counter so
# both can be restored from a single file.
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
write_json(
{"step": optimization_step, "interaction_step": interaction_step},
training_state_dir / TRAINING_STEP,
)
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
@@ -760,58 +647,6 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_discrete_critic = torch.optim.Adam(
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["discrete_critic"] = optimizer_discrete_critic
return optimizers, lr_scheduler
# Training setup functions
@@ -875,13 +710,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli
def load_training_state(
cfg: TrainRLServerPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer],
algorithm: RLAlgorithm | None = None,
device: str | torch.device = "cpu",
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Loads the training state (optimizers, RNG state, optimization and interaction steps, and
algorithm-owned tensors) from the most recent checkpoint.
Args:
cfg (TrainRLServerPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
cfg: Training configuration.
optimizers: Optimizers to load state into.
algorithm: Algorithm whose state dict should be restored.
Required for full main-equivalent resume;
the policy itself is restored separately via ``make_policy``.
device: Device on which to place loaded algorithm tensors.
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
@@ -890,20 +732,31 @@ def load_training_state(
return None, None
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
logging.info(f"Loading training state from {checkpoint_dir}")
try:
# Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Restore optimizers + RNG + step from the standard `training_state/` folder
step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load
interaction_step = training_state.get("interaction_step", 0)
# Restore algorithm-owned tensors
if algorithm is not None:
algo_dir = checkpoint_dir / ALGORITHM_DIR
if algo_dir.is_dir():
tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
algorithm.load_state_dict(tensors, device=device)
logging.info(f"Loaded algorithm state from {algo_dir}")
else:
logging.warning(
f"No algorithm state found at {algo_dir}; "
"will keep their freshly-initialised values. Adam moments restored from the "
"old optimizer state may not match these reset parameters."
)
# Read interaction_step from the enriched training_step.json
training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
@@ -1016,33 +869,6 @@ def initialize_offline_replay_buffer(
# Utilities/Helpers functions
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""
Get observation features from the policy encoder, acting as a cache.
When the encoder is frozen, the observation features do not change,
so caching them avoids recomputing the frozen vision encoder on every update.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
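A minimal usage sketch: compute the cached features once per batch and thread them through the update so the frozen encoder is not re-run. The batch keys below are illustrative, not the exact names the learner uses.
obs_feat, next_obs_feat = get_observation_features(policy, observations, next_observations)
if obs_feat is not None:
    # Both tensors were computed under torch.no_grad, so reusing them across
    # critic/actor updates is safe while the encoder stays frozen.
    batch["observation_feature"] = obs_feat          # illustrative key
    batch["next_observation_feature"] = next_obs_feat  # illustrative key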
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.learner == "threads"
@@ -1093,19 +919,11 @@ def check_nan_in_transition(
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None:
logging.debug("[LEARNER] Pushing actor policy to the queue")
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_dicts = algorithm.get_weights()
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
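On the actor side, a mirror-image sketch would pop the serialized weights and load them back; `bytes_to_state` and `load_weights` are assumed names for the inverses of `state_to_bytes` and `get_weights`, not confirmed API:
state_bytes = parameters_queue.get()
state_dicts = bytes_to_state(state_bytes)  # hypothetical inverse of state_to_bytes
algorithm.load_weights(state_dicts)        # hypothetical counterpart of get_weights()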
@@ -1129,9 +947,8 @@ def process_transitions(
transition_queue: Queue,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
device: str,
dataset_repo_id: str | None,
shutdown_event: any,
shutdown_event: Any, # Event
):
"""Process all available transitions from the queue.
@@ -1139,7 +956,6 @@ def process_transitions(
transition_queue: Queue for receiving transitions from the actor
replay_buffer: Replay buffer to add transitions to
offline_replay_buffer: Offline replay buffer to add transitions to
device: Device to move transitions to
dataset_repo_id: Repository ID for dataset
shutdown_event: Event to signal shutdown
"""
@@ -1148,8 +964,6 @@ def process_transitions(
transition_list = bytes_to_transitions(buffer=transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition=transition, device=device)
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],
@@ -1163,7 +977,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
TeleopEvents.IS_INTERVENTION
TeleopEvents.IS_INTERVENTION.value
):
offline_replay_buffer.add(**transition)
@@ -1172,7 +986,7 @@ def process_interaction_messages(
interaction_message_queue: Queue,
interaction_step_shift: int,
wandb_logger: WandBLogger | None,
shutdown_event: any,
shutdown_event: Any, # Event
) -> dict | None:
"""Process all available interaction messages from the queue.
+24 -7
@@ -18,17 +18,32 @@
import logging
import time
from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.utils.import_utils import _grpc_available
from .queue import get_last_item_from_queue
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
class LearnerService(_ServicerBase):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -51,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802
def StreamParameters( # noqa: N802
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
@@ -86,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
@@ -100,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
@@ -114,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802
def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
return services_pb2.Empty()
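For context, wiring this servicer into a server uses the standard generated gRPC API; the constructor arguments and port below are placeholders, not the repository's actual launch code:
from concurrent import futures

service = LearnerService(...)  # queues, shutdown_event, queue_get_timeout, ...
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
services_pb2_grpc.add_LearnerServiceServicer_to_server(service, server)
server.add_insecure_port("[::]:50051")  # port is a placeholder
server.start()
server.wait_for_termination()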
+50
@@ -0,0 +1,50 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Top-level pipeline config for distributed RL training (actor / learner)."""
from __future__ import annotations
from dataclasses import dataclass
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from .algorithms.configs import RLAlgorithmConfig
from .algorithms.factory import make_algorithm_config
from .algorithms.sac import SACAlgorithmConfig # noqa: F401
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class makes its type non-optional
# Algorithm config.
algorithm: RLAlgorithmConfig | None = None
# Data mixer strategy name. Currently supports "online_offline".
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer.
online_ratio: float = 0.5
def validate(self) -> None:
super().validate()
if self.algorithm is None:
self.algorithm = make_algorithm_config("sac")
if getattr(self.algorithm, "policy_config", None) is None:
self.algorithm.policy_config = self.policy
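A small sketch of the defaulting behaviour in validate(); fields inherited from TrainPipelineConfig are elided, and the sketch assumes the parent validate() passes for them:
cfg = TrainRLServerPipelineConfig(policy=policy_cfg, output_dir="outputs/rl")  # parent fields elided
cfg.validate()
assert isinstance(cfg.algorithm, SACAlgorithmConfig)  # defaulted to SAC when unset
assert cfg.algorithm.policy_config is cfg.policy      # back-filled from the pipeline policy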
+101
@@ -0,0 +1,101 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from lerobot.types import BatchType
from .algorithms.base import RLAlgorithm
from .algorithms.configs import TrainingStats
from .data_sources.data_mixer import DataMixer
class RLTrainer:
"""Unified training step orchestrator.
Holds the algorithm, a DataMixer, and an optional preprocessor.
"""
def __init__(
self,
algorithm: RLAlgorithm,
data_mixer: DataMixer,
batch_size: int,
*,
preprocessor: Any | None = None,
):
self.algorithm = algorithm
self.data_mixer = data_mixer
self.batch_size = batch_size
self._preprocessor = preprocessor
self._iterator: Iterator[BatchType] | None = None
self.algorithm.make_optimizers_and_scheduler()
def _build_data_iterator(self) -> Iterator[BatchType]:
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
raw = self.algorithm.configure_data_iterator(
data_mixer=self.data_mixer,
batch_size=self.batch_size,
)
if self._preprocessor is not None:
return _PreprocessedIterator(raw, self._preprocessor)
return raw
def reset_data_iterator(self) -> None:
"""Discard the current iterator so it will be rebuilt lazily next step."""
self._iterator = None
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
"""Swap the active data mixer, optionally resetting the iterator."""
self.data_mixer = data_mixer
if reset:
self.reset_data_iterator()
def training_step(self) -> TrainingStats:
"""Run one training step (algorithm-agnostic)."""
if self._iterator is None:
self._iterator = self._build_data_iterator()
return self.algorithm.update(self._iterator)
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
"""Apply policy preprocessing to RL observations only."""
observations = batch["state"]
next_observations = batch["next_state"]
batch["state"] = preprocessor.process_observation(observations)
batch["next_state"] = preprocessor.process_observation(next_observations)
return batch
class _PreprocessedIterator:
"""Iterator wrapper that preprocesses each sampled RL batch."""
__slots__ = ("_raw", "_preprocessor")
def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None:
self._raw = raw_iterator
self._preprocessor = preprocessor
def __iter__(self) -> _PreprocessedIterator:
return self
def __next__(self) -> BatchType:
batch = next(self._raw)
return preprocess_rl_batch(self._preprocessor, batch)
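A minimal sketch of the intended call pattern, assuming an algorithm and mixer built elsewhere (e.g. via the factories referenced in the config module):
trainer = RLTrainer(algorithm=algorithm, data_mixer=mixer, batch_size=256, preprocessor=preprocessor)
for _ in range(1000):
    stats = trainer.training_step()  # builds the (preprocessed) iterator lazily on first call
# Swapping data sources mid-run resets the iterator so the next step rebuilds it.
trainer.set_data_mixer(new_mixer)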
@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
clip_min: The minimum allowed gripper joint position.
clip_max: The maximum allowed gripper joint position.
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
discrete_gripper: If True, interpret the input as a discrete class index
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
"""
speed_factor: float = 20.0
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
raise ValueError("Joints observation is require for computing robot kinematics")
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_vel = (gripper_vel - 1) * self.clip_max
# Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
# Negation accounts for SO100 sign (joint position increases on close).
# 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
gripper_vel = -(gripper_vel - 1) * self.clip_max
# Compute desired gripper position
delta = gripper_vel * float(self.speed_factor)
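Worked through for an assumed clip_max of 1.0 (the default is outside this hunk), the new mapping gives:
for cmd in (0, 1, 2):
    vel = -(cmd - 1) * 1.0
    print(cmd, vel)  # 0 -> 1.0 (close), 1 -> 0.0 (stay), 2 -> -1.0 (open)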
+1
@@ -92,6 +92,7 @@ def get_sys_info() -> dict[str, str]:
info.update(
{
"PyTorch version": torch_version,
"Torchcodec version": get_package_version("torchcodec"),
"Is PyTorch built with CUDA support?": str(torch_cuda_available),
"Cuda version": cuda_version,
"GPU model": gpu_model,
-6
@@ -48,7 +48,6 @@ from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -402,10 +401,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
shuffle = True
sampler = None
# Only swap in the language-aware collate when the dataset actually
# declares language columns; otherwise stay on PyTorch's default
# collate so non-language training runs are unaffected.
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=cfg.num_workers,
@@ -414,7 +409,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
@@ -104,11 +104,14 @@ class KeyboardTeleop(Teleoperator):
def _on_press(self, key):
if hasattr(key, "char"):
self.event_queue.put((key.char, True))
key = key.char
self.event_queue.put((key, True))
def _on_release(self, key):
if hasattr(key, "char"):
self.event_queue.put((key.char, False))
key = key.char
self.event_queue.put((key, False))
if key == keyboard.Key.esc:
logging.info("ESC pressed, disconnecting.")
self.disconnect()
@@ -204,8 +207,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
# this is useful for retrieving other events like interventions for RL, episode success, etc.
self.misc_keys_queue.put(key)
self.current_pressed.clear()
action_dict = {
"delta_x": delta_x,
"delta_y": delta_y,
@@ -256,6 +257,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
]
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
self.current_pressed.clear()
# Check for episode control commands from misc_keys_queue
terminate_episode = False
success = False
@@ -39,8 +39,8 @@ For more details, see the [Physical Intelligence π₀ blog post](https://www.ph
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a central challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations never seen during training.
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
{% elif model_name == "sac" %}
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
{% elif model_name == "gaussian_actor" %}
This is a Gaussian Actor policy (a Gaussian policy with a tanh squash), the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
{% elif model_name == "reward_classifier" %}
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
{% else %}
+1
@@ -40,6 +40,7 @@ PolicyAction = torch.Tensor
RobotAction = dict[str, Any]
EnvAction = np.ndarray
RobotObservation = dict[str, Any]
BatchType = dict[str, Any]
EnvTransition = TypedDict(
-65
@@ -1,65 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
# carry `messages` and half don't) is a real bug in the upstream
# rendering step — silently filtering would hand downstream consumers a
# preserved list shorter than the tensor batch. Raise instead so the
# mismatch surfaces at the boundary.
preserved: dict[str, list[Any]] = {}
for key in _PYTHON_LIST_KEYS:
presence = [key in sample for sample in batch]
if not any(presence):
continue
if not all(presence):
raise ValueError(
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
f"every sample in a batch must agree."
)
preserved[key] = [sample[key] for sample in batch]
tensorizable = [
{
key: value
for key, value in sample.items()
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
}
for sample in batch
]
collated = default_collate(tensorizable)
collated.update(preserved)
return collated
+1
@@ -47,6 +47,7 @@ CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"
ALGORITHM_DIR = "algorithm"
RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
+1
@@ -132,6 +132,7 @@ _faker_available = is_package_available("faker")
_pynput_available = is_package_available("pynput")
_pygame_available = is_package_available("pygame")
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
_grpc_available = is_package_available("grpcio", import_name="grpc")
_wallx_deps_available = (
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
)
-168
@@ -1,168 +0,0 @@
#!/usr/bin/env python
from pathlib import Path
from textwrap import dedent
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
return MessageTurn(role="user", content=content, stream="high_level")
def _minimal_target_turn() -> MessageTurn:
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
# ── Message-recipe validation ────────────────────────────────────────
def test_message_recipe_validates_unknown_binding():
with pytest.raises(ValueError, match="unknown binding"):
TrainingRecipe(
messages=[
MessageTurn(role="user", content="${missing}", stream="high_level"),
_minimal_target_turn(),
]
)
def test_message_turn_requires_a_stream():
"""Every turn must declare a stream — None is rejected at construction.
Previously this only failed at render time (``_validate_rendered``);
catching it here means a malformed recipe YAML errors at load instead
of at the first training sample.
"""
with pytest.raises(ValueError, match="missing a stream"):
MessageTurn(role="user", content="${task}")
def test_message_recipe_requires_at_least_one_target():
with pytest.raises(ValueError, match="target"):
TrainingRecipe(
messages=[
_minimal_message_turn(),
MessageTurn(role="assistant", content="no target", stream="high_level"),
]
)
def test_recipe_rejects_both_messages_and_blend():
with pytest.raises(ValueError, match="only one"):
TrainingRecipe(
messages=[_minimal_message_turn(), _minimal_target_turn()],
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
)
def test_recipe_rejects_neither_messages_nor_blend():
with pytest.raises(ValueError, match="must set one"):
TrainingRecipe()
# ── Blend validation ─────────────────────────────────────────────────
def test_blend_must_be_non_empty():
with pytest.raises(ValueError, match="at least one component"):
TrainingRecipe(blend={})
def test_blend_component_must_define_weight():
with pytest.raises(ValueError, match="weight"):
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
def test_blend_component_weight_must_be_positive():
with pytest.raises(ValueError, match="positive weight"):
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
def test_blend_component_must_define_messages():
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
# going through __post_init__ to exercise the blend-level validator.
bad = TrainingRecipe.__new__(TrainingRecipe)
bad.messages = None
bad.bindings = None
bad.blend = None
bad.weight = 1.0
with pytest.raises(ValueError, match="must define messages"):
TrainingRecipe(blend={"a": bad})
def test_blend_components_cannot_themselves_define_a_blend():
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
# Force-bypass the inner component's normal validation so the test
# exercises the outer blend's "no nested blends" rule directly.
nested = TrainingRecipe.__new__(TrainingRecipe)
nested.messages = None
nested.bindings = None
nested.blend = inner.blend
nested.weight = 1.0
with pytest.raises(ValueError, match="cannot itself define a blend"):
TrainingRecipe(blend={"outer": nested})
# ── from_dict / from_yaml round-trips ────────────────────────────────
def test_from_dict_with_nested_blend():
recipe = TrainingRecipe.from_dict(
{
"blend": {
"a": {
"weight": 1.0,
"messages": [
{"role": "user", "content": "${task}", "stream": "high_level"},
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
],
},
"b": {
"weight": 2.0,
"messages": [
{"role": "user", "content": "${task}", "stream": "high_level"},
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
],
},
}
}
)
assert recipe.blend is not None
assert set(recipe.blend) == {"a", "b"}
assert recipe.blend["b"].weight == 2.0
# Inner messages were promoted to MessageTurn instances.
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
yaml_text = dedent(
"""
bindings:
custom: "active_at(t, style=subtask)"
messages:
- {role: user, content: "${task}: ${custom}", stream: high_level}
- {role: assistant, content: "ok", stream: high_level, target: true}
"""
).strip()
path = tmp_path / "recipe.yaml"
path.write_text(yaml_text)
via_classmethod = TrainingRecipe.from_yaml(path)
via_helper = load_recipe(path)
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
assert via_classmethod.messages[1].target is True
# ``load_recipe`` is just a wrapper, but assert the two paths agree
# on the structural result so a future divergence is caught here.
assert via_helper.bindings == via_classmethod.bindings
assert len(via_helper.messages) == len(via_classmethod.messages)
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
path = tmp_path / "bad.yaml"
path.write_text("- just\n- a\n- list\n")
with pytest.raises(ValueError, match="mapping at the top level"):
TrainingRecipe.from_yaml(path)
-81
@@ -385,84 +385,3 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0
# ── Tools accessor ───────────────────────────────────────────────────
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "no_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
assert meta.tools == DEFAULT_TOOLS
# info.json on disk should NOT include a `tools` key for clean datasets
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk
def test_tools_reads_declared_tools_from_info_json(tmp_path):
"""A `tools` list written into info.json survives load → meta.tools.
Regression test for the bug where ``DatasetInfo.from_dict`` silently
dropped the ``tools`` key (no matching dataclass field), so
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
what was on disk.
"""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "with_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/with_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
info_path = root / INFO_PATH
with open(info_path) as f:
raw = json.load(f)
raw["tools"] = [custom_tool]
with open(info_path, "w") as f:
json.dump(raw, f)
# Reload info from disk and rebind it on the metadata object
meta.info = load_info(root)
assert meta.tools == [custom_tool]
def test_tools_round_trip_through_dataset_info(tmp_path):
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
from lerobot.datasets.utils import DatasetInfo
raw = {
"codebase_version": "v3.1",
"fps": 30,
"features": SIMPLE_FEATURES,
"tools": [{"type": "function", "function": {"name": "say"}}],
}
info = DatasetInfo.from_dict(raw)
assert info.tools == raw["tools"]
assert info.to_dict()["tools"] == raw["tools"]
+65
@@ -1691,3 +1691,68 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
threshold = sorted(lengths)[len(lengths) // 2]
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
expected_frames = sum(lengths[i] for i in expected_eps)
filtered = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["length"] >= threshold,
)
assert filtered.num_episodes == len(expected_eps)
assert filtered.num_frames == expected_frames
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
assert seen_eps == set(expected_eps)
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
candidates = [0, 2, 4, 6]
candidate_lengths = [lengths[i] for i in candidates]
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
expected_eps = [i for i in candidates if lengths[i] >= threshold]
filtered = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=candidates,
episode_filter=lambda ep: ep["length"] >= threshold,
)
assert filtered.num_episodes == len(expected_eps)
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
assert seen_eps == set(expected_eps)
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
with pytest.raises(ValueError, match=r"The episode filter did not match any episode"):
LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["length"] < 0,
)
def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory):
"""A predicate referencing a column absent from meta.episodes surfaces a clear KeyError."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
with pytest.raises(KeyError, match="not_a_real_field"):
LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["not_a_real_field"] > 0,
)
-156
@@ -1,156 +0,0 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import numpy as np # noqa: E402
import pandas as pd # noqa: E402
import pyarrow as pa # noqa: E402
from lerobot.datasets import LeRobotDataset # noqa: E402
from lerobot.datasets.io_utils import write_info # noqa: E402
from lerobot.datasets.language import ( # noqa: E402
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
VIEW_DEPENDENT_STYLES,
column_for_style,
is_view_dependent_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
validate_camera_field,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH # noqa: E402
def test_language_arrow_schema_has_expected_fields():
persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == [
"role",
"content",
"style",
"timestamp",
"camera",
"tool_calls",
]
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
assert column_for_style("plan") == LANGUAGE_PERSISTENT
assert column_for_style("memory") == LANGUAGE_PERSISTENT
assert column_for_style("motion") == LANGUAGE_PERSISTENT
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
assert column_for_style("interjection") == LANGUAGE_EVENTS
assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style("trace") == LANGUAGE_EVENTS
assert column_for_style(None) == LANGUAGE_EVENTS
def test_view_dependent_styles():
# motion lives in PERSISTENT_STYLES and is described in robot-frame
# (joint / Cartesian) terms, so it is NOT view-dependent. Only vqa
# (event) and trace (event, pixel-trajectory) carry a camera tag.
assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
assert is_view_dependent_style("vqa")
assert is_view_dependent_style("trace")
assert not is_view_dependent_style("motion")
assert not is_view_dependent_style("subtask")
assert not is_view_dependent_style("plan")
assert not is_view_dependent_style("interjection")
assert not is_view_dependent_style(None)
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
validate_camera_field("vqa", "observation.images.top")
validate_camera_field("trace", "observation.images.front")
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("vqa", None)
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("trace", "")
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
validate_camera_field("subtask", None)
validate_camera_field("plan", None)
validate_camera_field("memory", None)
validate_camera_field("motion", None)
validate_camera_field("interjection", None)
validate_camera_field(None, None)
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("subtask", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("motion", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("interjection", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field(None, "observation.images.top")
def test_unknown_style_rejected():
with pytest.raises(ValueError, match="Unknown language style"):
column_for_style("surprise")
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
root = tmp_path / "language_dataset"
dataset = empty_lerobot_dataset_factory(
root=root,
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
use_videos=False,
)
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
dataset.save_episode()
dataset.finalize()
persistent = [
{
"role": "assistant",
"content": "reach for the cup",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
]
event = {
"role": "user",
"content": "what is visible?",
"style": "vqa",
"camera": "observation.images.top",
"tool_calls": None,
}
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
df = pd.read_parquet(data_path)
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
df[LANGUAGE_EVENTS] = [[event], []]
df.to_parquet(data_path)
info = dataset.meta.info
info["features"].update(language_feature_info())
write_info(info, root)
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
first = reloaded[0]
second = reloaded[1]
assert first[LANGUAGE_PERSISTENT] == persistent
assert first[LANGUAGE_EVENTS] == [event]
assert second[LANGUAGE_PERSISTENT] == persistent
assert second[LANGUAGE_EVENTS] == []
-417
@@ -1,417 +0,0 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.datasets.language_render import ( # noqa: E402
EMITTED_AT_TOLERANCE_S,
active_at,
emitted_at,
nth_next,
nth_prev,
render_sample,
)
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"timestamp": timestamp,
"camera": camera,
"tool_calls": tool_calls,
}
def event_row(role, content, style, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"camera": camera,
"tool_calls": tool_calls,
}
PERSISTENT = [
persistent_row("assistant", "plan 0", "plan", 0.0),
persistent_row("assistant", "memory 0", "memory", 0.0),
persistent_row("assistant", "subtask 0", "subtask", 0.0),
persistent_row("assistant", "memory 1", "memory", 1.0),
persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
]
EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant",
None,
None,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
),
]
# Same emission tick, two cameras: triggers per-camera disambiguation in
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
# (vqa, user) + (vqa, assistant) pair per camera.
EVENTS_AT_3_TWO_CAMERAS = [
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
]
def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}'
)
def test_persistent_relative_resolvers_reject_event_styles():
with pytest.raises(ValueError, match="event-only"):
active_at(1.0, persistent=PERSISTENT, style="vqa")
with pytest.raises(ValueError, match="event-only"):
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
def test_nth_prev_and_next():
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
def test_substitution_if_present_multimodal_and_tool_calls():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${task}: ${interjection}"},
],
stream="high_level",
if_present="interjection",
),
MessageTurn(
role="assistant",
content="${plan}",
stream="high_level",
target=True,
tool_calls_from="speech",
),
],
bindings={"plan": "active_at(t, style=plan)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=EVENTS_AT_2,
t=2.0,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
assert rendered["messages"][1]["content"] == "plan 0"
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
assert rendered["message_streams"] == ["high_level", "high_level"]
assert rendered["target_message_indices"] == [1]
def test_exact_event_miss_returns_none_when_target_skips():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
]
)
assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling():
recipe = TrainingRecipe(
blend={
"a": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
],
),
"b": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
],
),
}
)
first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
assert first == second
def test_emitted_at_filters_vqa_by_camera():
top = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.top",
)
wrist = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.wrist",
)
assert top["content"] == '{"count": 3}'
assert wrist["content"] == '{"count": 1}'
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
with pytest.raises(ValueError, match="Ambiguous resolver"):
emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
)
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
return TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
},
messages=[
MessageTurn(
role="user",
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
)
@pytest.mark.parametrize(
("camera", "expected_query", "expected_answer"),
[
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
],
)
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
rendered = render_sample(
recipe=_vqa_subrecipe(camera),
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
assert rendered["messages"][0]["content"][0]["feature"] == camera
assert rendered["messages"][0]["content"][1]["text"] == expected_query
assert rendered["messages"][1]["content"] == expected_answer
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
rephrasings = [
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
# No explicit task override → resolver consults persistent rows.
seen: set[str] = set()
for sample_idx in range(64):
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=sample_idx,
dataset_ctx={"task": "canonical kitchen task"},
)
seen.add(rendered["messages"][0]["content"])
# Every rephrasing should be reachable across enough samples.
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT, # no task_aug rows
events=[],
t=0.0,
sample_idx=0,
dataset_ctx={"task": "clean the kitchen"},
)
assert rendered["messages"][0]["content"] == "clean the kitchen"
def test_resolve_task_explicit_override_beats_rephrasings():
rephrasings = [
persistent_row("user", "rephrased one", "task_aug", 0.0),
persistent_row("user", "rephrased two", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=0,
task="explicit override wins",
dataset_ctx={"task": "canonical"},
)
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
line up with the parquet-stored timestamp.
"""
rows = [persistent_row("assistant", "memo", "memory", 1.0)]
# Half a tolerance window — bit-different float, comfortably inside
inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
assert inside is not None and inside["content"] == "memo"
# Just past the window — no match
outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
assert outside is None
def test_render_sample_rejects_non_dict_language_rows():
"""``_normalize_rows`` must surface malformed inputs as TypeError.
A pipeline that hands the renderer a non-dict (e.g. a stray string)
is a real upstream bug; silent skipping would let it propagate.
"""
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
with pytest.raises(TypeError, match="must be dictionaries"):
render_sample(
recipe=recipe,
persistent=["not a dict"],
events=[],
t=0.0,
sample_idx=0,
task="x",
)
def test_low_level_branch_renders_active_subtask():
low_level = TrainingRecipe(
blend={
"low": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(
role="user",
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
stream="high_level",
),
MessageTurn(
role="assistant",
content="${subtask}",
stream="low_level",
target=True,
),
],
)
}
)
rendered = render_sample(
recipe=low_level,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
assert rendered["message_streams"][-1] == "low_level"
assert rendered["target_message_indices"] == [1]
+193
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
# Use the lerobot/pusht-subtask dataset with episodes 1-11
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask
@@ -17,19 +17,19 @@
import pytest
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import (
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
GaussianActorConfig,
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
config = SACConfig()
def test_gaussian_actor_config_default_initialization():
config = GaussianActorConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
@@ -55,9 +55,6 @@ def test_sac_config_default_initialization():
# Basic parameters
assert config.device == "cpu"
assert config.storage_device == "cpu"
assert config.discount == 0.99
assert config.temperature_init == 1.0
assert config.num_critics == 2
# Architecture specifics
assert config.vision_encoder_name is None
@@ -66,6 +63,8 @@ def test_sac_config_default_initialization():
assert config.shared_encoder is True
assert config.num_discrete_actions is None
assert config.image_embedding_pooling_dim == 8
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
# Training parameters
assert config.online_steps == 1000000
@@ -73,20 +72,6 @@ def test_sac_config_default_initialization():
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False
assert config.online_step_before_learning == 100
assert config.policy_update_freq == 1
# SAC algorithm parameters
assert config.num_subsample_critics is None
assert config.critic_lr == 3e-4
assert config.actor_lr == 3e-4
assert config.temperature_lr == 3e-4
assert config.critic_target_update_weight == 0.005
assert config.utd_ratio == 1
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
assert config.target_entropy is None
assert config.use_backup_entropy is True
assert config.grad_clip_norm == 40.0
# Dataset stats defaults
expected_dataset_stats = {
@@ -105,11 +90,6 @@ def test_sac_config_default_initialization():
}
assert config.dataset_stats == expected_dataset_stats
# Critic network configuration
assert config.critic_network_kwargs.hidden_dims == [256, 256]
assert config.critic_network_kwargs.activate_final is True
assert config.critic_network_kwargs.final_activation is None
# Actor network configuration
assert config.actor_network_kwargs.hidden_dims == [256, 256]
assert config.actor_network_kwargs.activate_final is True
@@ -135,7 +115,6 @@ def test_sac_config_default_initialization():
assert config.concurrency.learner == "threads"
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
assert isinstance(config.policy_kwargs, PolicyConfig)
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
assert isinstance(config.concurrency, ConcurrencyConfig)
@@ -175,22 +154,22 @@ def test_concurrency_config():
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
num_critics=3,
latent_dim=128,
state_encoder_hidden_dim=128,
num_discrete_actions=3,
)
assert config.device == "cpu"
assert config.discount == 0.95
assert config.temperature_init == 0.5
assert config.num_critics == 3
assert config.latent_dim == 128
assert config.state_encoder_hidden_dim == 128
assert config.num_discrete_actions == 3
def test_validate_features():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -198,7 +177,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
config = GaussianActorConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -209,7 +188,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -0,0 +1,528 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from torch import Tensor, nn # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 42
set_seed(seed)
def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(10)
y = mlp(x)
assert y.shape == (256,)
def test_mlp_with_batch_dim():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(2, 10)
y = mlp(x)
assert y.shape == (2, 256)
def test_forward_with_empty_hidden_dims():
mlp = MLP(input_dim=10, hidden_dims=[])
x = torch.randn(1, 10)
assert mlp(x).shape == (1, 10)
def test_mlp_with_dropout():
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 11)
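# The MLP is expected to place a Dropout after every hidden layer except the last,
# so hidden_dims=[256, 256, 11] should yield exactly two nn.Dropout modules.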
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
assert drop_out_layers_count == 2
def test_mlp_with_custom_final_activation():
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 256)
assert (y >= -1).all() and (y <= 1).all()
def test_gaussian_actor_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
GaussianActorPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
return torch.randn(batch_size, action_dim)
def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> GaussianActorConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
ACTION: {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
},
)
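# validate_features() should raise if the required observation/action features are
# missing (see the dedicated validation tests for GaussianActorConfig).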
config.validate_features()
return config
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> GaussianActorConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
config.validate_features()
return config
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
"""Helper to create policy + algorithm pair for tests that need critics."""
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
return algorithm, policy
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
# squeeze(0) removes batch dim when batch_size==1
assert selected_action.shape[-1] == action_dim
def test_gaussian_actor_policy_select_action_with_discrete():
"""select_action should return continuous + discrete actions."""
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.num_discrete_actions = 3
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
# Squeeze to unbatched (single observation)
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
selected_action = policy.select_action(observation_batch)
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.eval()
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
with torch.no_grad():
output = policy.forward(batch)
assert "action" in output
assert "log_prob" in output
assert "action_mean" in output
assert output["action"].shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
forward_batch = algorithm._prepare_forward_batch(batch)
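# Standard SAC update order: critic first, then actor, then temperature, each via its own optimizer.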
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
temp_loss = algorithm._compute_loss_temperature(forward_batch)
assert temp_loss.item() is not None
assert temp_loss.shape == ()
algorithm.optimizers["temperature"].zero_grad()
temp_loss.backward()
algorithm.optimizers["temperature"].step()
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape[-1] == action_dim
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_gaussian_actor_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_gaussian_actor_training_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
def test_gaussian_actor_training_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
config.num_discrete_actions = 5
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
assert discrete_critic_loss.shape == ()
algorithm.optimizers["discrete_critic"].zero_grad()
discrete_critic_loss.backward()
algorithm.optimizers["discrete_critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
# Policy.select_action now handles both continuous + discrete
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
assert selected_action.shape[-1] == continuous_action_dim + 1
def test_sac_algorithm_target_entropy():
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
config = create_default_config(continuous_action_dim=10, state_dim=10)
algorithm, _ = _make_algorithm(config)
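# The default target entropy appears to follow -action_dim / 2, i.e. -10 / 2 = -5.0 here.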
assert algorithm.target_entropy == -5.0
def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algorithm, _ = _make_algorithm(config)
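# 6 continuous dims plus 1 discrete slot give a 7-dim action, hence -7 / 2 = -3.5.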
assert algorithm.target_entropy == -3.5
def test_sac_algorithm_temperature():
import math
config = create_default_config(continuous_action_dim=10, state_dim=10)
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.temperature == pytest.approx(1.0)
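# temperature is exp(log_alpha), so setting log_alpha = log(0.1) should yield exactly 0.1.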
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
assert algorithm.temperature == pytest.approx(0.1)
def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.critic_target_update_weight = 1.0
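# An update weight of 1.0 turns the soft (Polyak) target update into a hard copy.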
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
for p in algorithm.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
algorithm._update_target_networks()
for p in algorithm.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data))
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.num_critics = num_critics
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
assert len(algorithm.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_gaussian_actor_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
with torch.no_grad():
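# Re-seeding with the same value makes both random observation batches identical,
# so the two policies must produce matching actions.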
with seeded_context(12):
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
assert torch.allclose(actions, loaded_actions)
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
state_dim = 10
action_dim = 6
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_discrete_actions = 3
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert loaded_policy.discrete_critic is not None
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
@@ -1,546 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 42
set_seed(seed)
def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(10)
y = mlp(x)
assert y.shape == (256,)
def test_mlp_with_batch_dim():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(2, 10)
y = mlp(x)
assert y.shape == (2, 256)
def test_forward_with_empty_hidden_dims():
mlp = MLP(input_dim=10, hidden_dims=[])
x = torch.randn(1, 10)
assert mlp(x).shape == (1, 10)
def test_mlp_with_dropout():
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 11)
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
assert drop_out_layers_count == 2
def test_mlp_with_custom_final_activation():
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 256)
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
return torch.randn(batch_size, action_dim)
def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
"""Create optimizers for the SAC policy."""
optimizer_actor = torch.optim.Adam(
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(),
lr=policy.config.critic_lr,
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha],
lr=policy.config.critic_lr,
)
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if has_discrete_action:
optimizers["discrete_critic"] = torch.optim.Adam(
params=policy.discrete_critic.parameters(),
lr=policy.config.critic_lr,
)
return optimizers
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
ACTION: {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
},
)
config.validate_features()
return config
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}
# Let's make tests a little bit faster
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
# Let's check best candidates for pretrained encoders
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_sac_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_sac_policy_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
def test_sac_policy_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1 # the last action is discrete
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
num_discrete_actions = 5
config.num_discrete_actions = num_discrete_actions
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
policy.train()
optimizers = make_optimizers(policy, has_discrete_action=True)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
assert discrete_critic_loss.item() is not None
assert discrete_critic_loss.shape == ()
discrete_critic_loss.backward()
optimizers["discrete_critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, full_action_dim)
discrete_actions = selected_action[:, -1].long()
discrete_action_values = set(discrete_actions.tolist())
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
)
def test_sac_policy_with_default_entropy():
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.target_entropy == -5.0
def test_sac_policy_default_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
policy = SACPolicy(config=config)
assert policy.target_entropy == -3.0
def test_sac_policy_with_predefined_entropy():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.target_entropy = -3.5
policy = SACPolicy(config=config)
assert policy.target_entropy == pytest.approx(-3.5)
def test_sac_policy_update_temperature():
"""Test that temperature property is always in sync with log_alpha."""
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.temperature == pytest.approx(1.0)
policy.log_alpha.data = torch.tensor([math.log(0.1)])
# Temperature property automatically reflects log_alpha changes
assert policy.temperature == pytest.approx(0.1)
def test_sac_policy_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
policy = SACPolicy(config=config)
policy.train()
for p in policy.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
policy.update_target_networks()
for p in policy.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data)), (
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
)
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
policy = SACPolicy(config=config)
policy.train()
assert len(policy.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
root = tmp_path / "test_sac_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
with torch.no_grad():
with seeded_context(12):
# Collect policy values before saving
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
# Collect policy values after loading
loaded_critic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
# They should be the same
assert torch.allclose(critic_loss, loaded_critic_loss)
assert torch.allclose(actor_loss, loaded_actor_loss)
assert torch.allclose(temperature_loss, loaded_temperature_loss)
assert torch.allclose(actions, loaded_actions)
@@ -21,8 +21,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
@@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE
def create_default_config():
"""Create a default SAC configuration for testing."""
config = SACConfig()
config = GaussianActorConfig()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
}
@@ -66,7 +66,7 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -88,12 +88,12 @@ def test_make_sac_processor_basic():
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
def test_sac_processor_normalization_modes():
def test_gaussian_actor_processor_normalization_modes():
"""Test that SAC processor correctly handles different normalization modes."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_cuda():
def test_gaussian_actor_processor_cuda():
"""Test SAC processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -153,13 +153,13 @@ def test_sac_processor_cuda():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_accelerate_scenario():
def test_gaussian_actor_processor_accelerate_scenario():
"""Test SAC processor in simulated Accelerate scenario."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_sac_processor_multi_gpu():
def test_gaussian_actor_processor_multi_gpu():
"""Test SAC processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu():
assert processed[TransitionKey.ACTION.value].device == device
def test_sac_processor_without_stats():
def test_gaussian_actor_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -226,12 +226,12 @@ def test_sac_processor_without_stats():
assert processed is not None
def test_sac_processor_save_and_load():
def test_gaussian_actor_processor_save_and_load():
"""Test saving and loading SAC processor."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -257,14 +257,14 @@ def test_sac_processor_save_and_load():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_mixed_precision():
def test_gaussian_actor_processor_mixed_precision():
"""Test SAC processor with mixed precision."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision():
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
def test_sac_processor_batch_data():
def test_gaussian_actor_processor_batch_data():
"""Test SAC processor with batched data."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -329,12 +329,12 @@ def test_sac_processor_batch_data():
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
def test_sac_processor_edge_cases():
def test_gaussian_actor_processor_edge_cases():
"""Test SAC processor with edge cases."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -358,13 +358,13 @@ def test_sac_processor_edge_cases():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_bfloat16_device_float32_normalizer():
def test_gaussian_actor_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_sac_pre_post_processors(
preprocessor, _ = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict():
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
), f"Stats for {key}.{stat_name} should not match original stats"
# Verify that _tensor_stats are also correctly set to match the override stats
# Verify that _tensor_stats values match the override stats
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats
expected_tensor_stats = to_tensor(override_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
torch.testing.assert_close(
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
override_normalizer._tensor_stats[key][stat_name].squeeze(),
expected_tensor_stats[key][stat_name].squeeze(),
)
@@ -1849,12 +1851,16 @@ def test_stats_without_override_loads_normally():
# Stats should now match the original stats (normal behavior)
# Check that all keys and values match
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats,
# so we squeeze before comparing values.
for key in original_stats:
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
for stat_name in original_stats[key]:
np.testing.assert_allclose(
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
)
actual = new_normalizer.stats[key][stat_name]
expected = original_stats[key][stat_name]
if hasattr(actual, "squeeze"):
actual = actual.squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)
def test_stats_explicit_provided_flag_detection():
@@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict():
assert ACTION in new_normalizer.stats
# Check that values are correct (converted back from tensors)
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2])
# Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2])
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])
@@ -1,60 +0,0 @@
#!/usr/bin/env python
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.processor.converters import create_transition # noqa: E402
from lerobot.processor.render_messages_processor import RenderMessagesStep # noqa: E402
from lerobot.types import TransitionKey # noqa: E402
def test_render_messages_step_noops_without_language_columns():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(complementary_data={"task": "do it"})
assert RenderMessagesStep(recipe)(transition) == transition
def test_render_messages_step_renders_and_drops_raw_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(
complementary_data={
"task": "do it",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [
{
"role": "assistant",
"content": "reach carefully",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
],
"language_events": [],
}
)
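# Rendering should substitute the assistant turn from language_persistent and
# drop the raw language columns from complementary data.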
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "language_persistent" not in data
assert "language_events" not in data
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@@ -36,9 +35,6 @@ def test_classifier_output():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -80,9 +76,6 @@ def test_binary_classifier_with_default_params():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -122,9 +115,6 @@ def test_multiclass_classifier():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -141,9 +131,6 @@ def test_default_device():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -22,12 +22,14 @@ import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("grpc")
from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import OBS_STR
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing
@@ -79,7 +81,7 @@ def cfg():
port = find_free_port()
policy_cfg = SACConfig()
policy_cfg = GaussianActorConfig()
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
policy_cfg.actor_learner_config.learner_port = port
policy_cfg.concurrency.actor = "threads"
@@ -299,3 +301,164 @@ def test_end_to_end_parameters_flow(cfg, data_size):
assert received_params.keys() == input_params.keys()
for key in input_params:
assert torch.allclose(received_params[key], input_params[key])
def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
get_weights() output is serializable."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.transport.utils import state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
)
sac_cfg.validate_features()
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
assert "critic" in optimizers
assert "temperature" in optimizers
batch_size = 4
def batch_iterator():
while True:
yield {
ACTION: torch.randn(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": {OBS_STATE: torch.randn(batch_size, state_dim)},
"next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
"done": torch.zeros(batch_size),
"complementary_info": {},
}
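# update() is expected to pull batches from the iterator and report the resulting losses.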
stats = algorithm.update(batch_iterator())
assert "loss_critic" in stats.losses
# get_weights -> state_to_bytes round-trip
weights = algorithm.get_weights()
assert len(weights) > 0
serialized = state_to_bytes(weights)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
# RLTrainer with DataMixer
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.rl.trainer import RLTrainer
replay_buffer = ReplayBuffer(
capacity=50,
device="cpu",
state_keys=[OBS_STATE],
storage_device="cpu",
use_drq=False,
)
for _ in range(50):
replay_buffer.add(
state={OBS_STATE: torch.randn(state_dim)},
action=torch.randn(action_dim),
reward=1.0,
next_state={OBS_STATE: torch.randn(state_dim)},
done=False,
truncated=False,
)
data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=batch_size,
)
trainer_stats = trainer.training_step()
assert "loss_critic" in trainer_stats.losses
def test_initial_and_periodic_weight_push_consistency():
"""Both initial and periodic weight pushes should use algorithm.get_weights()
and produce identical structures."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
)
sac_cfg.validate_features()
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
algorithm.make_optimizers_and_scheduler()
# Simulate initial push (same code path the learner now uses)
initial_weights = algorithm.get_weights()
initial_bytes = state_to_bytes(initial_weights)
# Simulate periodic push
periodic_weights = algorithm.get_weights()
periodic_bytes = state_to_bytes(periodic_weights)
initial_decoded = bytes_to_state_dict(initial_bytes)
periodic_decoded = bytes_to_state_dict(periodic_bytes)
assert initial_decoded.keys() == periodic_decoded.keys()
def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
state_dim = 10
action_dim = 6
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
)
sac_cfg.validate_features()
# Actor side: no optimizers
policy = GaussianActorPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
# select_action should work
obs = {OBS_STATE: torch.randn(state_dim)}
action = policy.select_action(obs)
assert action.shape == (action_dim,)
# Simulate receiving weights from learner
fake_weights = algorithm.get_weights()
algorithm.load_weights(fake_weights, device="cpu")
@@ -0,0 +1,89 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.rl.buffer import ReplayBuffer # noqa: E402
from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402
from lerobot.utils.constants import OBS_STATE # noqa: E402
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
buf = ReplayBuffer(
capacity=capacity,
device="cpu",
state_keys=[OBS_STATE],
storage_device="cpu",
use_drq=False,
)
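# Fill the buffer to capacity with random transitions; every 10th one is terminal
# to exercise episode boundaries.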
for i in range(capacity):
buf.add(
state={OBS_STATE: torch.randn(state_dim)},
action=torch.randn(2),
reward=1.0,
next_state={OBS_STATE: torch.randn(state_dim)},
done=bool(i % 10 == 9),
truncated=False,
)
return buf
def test_online_only_mixer_sample():
"""OnlineOfflineMixer with no offline buffer returns online-only batches."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5)
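# With no offline buffer, online_ratio is presumably ignored and every sample comes from the online buffer.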
batch = mixer.sample(batch_size=8)
assert batch["state"][OBS_STATE].shape[0] == 8
assert batch["action"].shape[0] == 8
assert batch["reward"].shape[0] == 8
def test_online_only_mixer_ratio_one():
"""OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0)
batch = mixer.sample(batch_size=10)
assert batch["state"][OBS_STATE].shape[0] == 10
def test_online_offline_mixer_sample():
"""OnlineOfflineMixer with two buffers returns concatenated batches."""
online = _make_buffer(capacity=50)
offline = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(
online_buffer=online,
offline_buffer=offline,
online_ratio=0.5,
)
batch = mixer.sample(batch_size=10)
assert batch["state"][OBS_STATE].shape[0] == 10
assert batch["action"].shape[0] == 10
# 5 from online, 5 from offline (approx)
assert batch["reward"].shape[0] == 10
def test_online_offline_mixer_iterator():
"""get_iterator yields batches of the requested size."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None)
it = mixer.get_iterator(batch_size=4, async_prefetch=False)
batch1 = next(it)
batch2 = next(it)
assert batch1["state"][OBS_STATE].shape[0] == 4
assert batch2["state"][OBS_STATE].shape[0] == 4
@@ -20,7 +20,7 @@ from queue import Queue
import pytest
pytest.importorskip("grpc")
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402

Some files were not shown because too many files have changed in this diff.