Compare commits

..

16 Commits

Author SHA1 Message Date
Steven Palma 2438df1307 chore(dependencies): update uv.lock (#3561) 2026-05-12 21:20:26 +02:00
Caroline Pascal f218d5ab30 feat(episodes): adding support for metadata based episodes filtering (#3530)
* feat(episode filtering): adding support for episodes filtering at initialization time in LeRobotDataset

* test(tests): adding tests

* chore(format): formatting code

* feat(performance): improving implementation for better performances on big datasets

* chores(warning): improving warnings and errors for episodes filtering

* test(invalid key): adding test for invalid filtering key

* chore(format): formatting code
2026-05-12 20:44:11 +02:00
Steven Palma 04125492e4 fix(datasets): expand torchcodec platform coverage + rewrite pyav fallback for torchvision >0.26 (#3588)
* fix(deps): better versioning control for torchcodec

* refactor(video_utils): replace torchvision with pyav

* adding Torchcodec version to lerobot-info

* chore(benchmarks): delete video benchmark

---------

Co-authored-by: Maximellerbach <maxime.ellerbach@huggingface.co>
2026-05-12 16:59:11 +02:00
Khalil Meftah e963e5a0c4 RL stack refactoring (#3075)
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring

* chore: clarify torch.compile disabled note in SACAlgorithm

* fix(teleop): keyboard EE teleop not registering special keys and losing intervention state

Fixes #2345

Co-authored-by: jpizarrom <jpizarrom@gmail.com>

* fix: remove leftover normalization calls from reward classifier predict_reward

Fixes #2355

* fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample()

* refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

* perf: remove redundant CPU→GPU→CPU transition move in learner

* Fix: add kwargs in reward classifier __init__()

* fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer

* fix: add try/finally to control_loop to ensure image writer cleanup on exit

* fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error

* fix: skip tests that require grpc if not available

* fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

* fix(tests): skip tests that require grpc if not available

* refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

* fix(config): update vision encoder model name to lerobot/resnet10

* fix(sac): clarify torch.compile status

* refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

* refactor(sac): simplify optimizer return structure

* perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

* refactor(sac): decouple algorithm hyperparameters from policy config

* update losses names in tests

* fix docstring

* remove unused type alias

* fix test for flat dict structure

* refactor(policies): rename policies/sac → policies/gaussian_actor

* refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

* perf(observation_processor): add CUDA support for image processing

* fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline

(cherry picked from commit 9c2af818ff)

* fix(rl): add time limit processor to environment pipeline

(cherry picked from commit cd105f65cb)

* fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100

(cherry picked from commit 494f469a2b)

* fix(rl): update neutral gripper action

(cherry picked from commit 9c9064e5be)

* fix(rl): merge environment and action-processor info in transition processing

(cherry picked from commit 30e1886b64)

* fix(rl): mirror gym_manipulator in actor

(cherry picked from commit d2a046dfc5)

* fix(rl): postprocess action in actor

(cherry picked from commit c2556439e5)

* fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)

* fix(rl): enhance intervention handling in actor and learner

(cherry picked from commit ef8bfffbd7)

* Revert "perf(observation_processor): add CUDA support for image processing"

This reverts commit 38b88c414c.

* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

* refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

* refactor(rl): add type property to RLAlgorithmConfig for better clarity

* refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility

* refactor(tests): remove grpc import checks from test files for cleaner code

* fix(tests): gate RL tests on the `datasets` extra

* refactor: simplify docstrings for clarity and conciseness across multiple files

* fix(rl): update gripper position key and handle action absence during reset

* fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

* refactor: clean up import statements

* chore: address reviewer comments

* chore: improve visual stats reshaping logic and update docstring for clarity

* refactor: enforce mandatory config_class and name attributes in RLAlgorithm

* refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

* refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

* refactor: add require_package calls for grpcio and gym-hil in relevant modules

* refactor(rl): move grpcio guards to runtime entry points

* feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.

* refactor(rl): move actor weight-sync wire format from policy to algorithm

* refactor(rl): update type hints for learner and actor functions

* refactor(rl): hoist grpcio guard to module top in actor/learner

* chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>

* update uv.lock

* chore(doc): update doc

---------

Co-authored-by: jpizarrom <jpizarrom@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-12 15:49:54 +02:00
Steven Palma 26ff40ddd7 chore(deps): cap torch ceiling at <2.12, pin Linux wheels to cu128 (#3570)
* chore(deps): ceiling + cuda

* ci: bump cuda version docker image

* ci: add cpu wheel to release workflow

* chore(deps): update uv.lock

* docs: update installation with cuda note
2026-05-11 19:47:55 +02:00
Maxime Ellerbach 6d269b28c8 docs(omx): adding some examples and scripts (#3566)
* docs(omx): adding some examples and scripts

* cleaning up and reviewing the cli args

* adding __init__.py to example folder, adjusting the examples

* adding reference to pretrained act policy

* moving `.send_action` before `dataset.add_frame` for consistency

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adjusting docstring

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adressing hardcoded dataset fps

* removed init as it worked without

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-05-11 15:36:32 +02:00
Steven Palma b607c8458e docs: add policy & compute guide (#3534)
* docs(policy): contributing a policy guide

* docs(training): HW compute guide

* chore(docs): add to readme and index

* Apply suggestions from code review

Co-authored-by: Haoming Song <1847575517@qq.com>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* chore(docs): slight improvements

* refactor(docs): consolidate add policy docs

* chore(style): fix pre-commit

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Haoming Song <1847575517@qq.com>
2026-05-11 15:19:12 +02:00
Jash Shah 9e83510c99 fix(datasets): close file handle on VideoDecoder init failure in cache (#3542)
If VideoDecoder() raises during initialization, the fsspec file handle
was leaked since it was opened via __enter__() but never closed on the
exception path. Now explicitly closes the handle before re-raising.
2026-05-10 17:30:37 +02:00
Anthony Shoumikhin 1f7b03f5f2 chore(deps): allow torch 2.11/2.12 and fix autocast deprecation (#3435)
* chore(deps): allow torch 2.11/2.12 and fix autocast deprecation

- Bump torch to >=2.7,<2.13 (was <2.11), torchvision to <0.28 (was <0.26),
  and torchcodec to <0.13 (was <0.11) to allow installs against the latest
  stable torch 2.11 and the upcoming 2.12 line.
- Replace removed torch.get_autocast_gpu_dtype() with torch.get_autocast_dtype("cuda")
  in Florence2 and Qwen2.5-VL-MoE FlashAttention paths (the former is removed in 2.11+).
- Refresh uv.lock for the new resolution (torch 2.11.0+cu130, torchvision 0.26.0+cu130,
  torchcodec 0.11.1, full CUDA 13 stack).

Verified locally with `uv sync --locked` from a clean .venv and the lerobot
test suite (pytest -n 8 --dist=loadfile --timeout=300). Failure set is
identical to the pre-bump baseline: 18 pre-existing failures
(test_sac_policy*, test_pi0_rtc*, test_pi05_rtc*, test_replay_buffer*),
0 new, 0 fixed.

AI assistance: this change was authored with Claude Code per AI_POLICY.md.

* fix(policies): use device-agnostic autocast dtype lookup

Pass query_states.device.type to torch.get_autocast_dtype() instead of
hardcoding 'cuda', so the cast matches the active autocast context when
running under CPU/MPS/XPU autocast.

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-10 13:05:35 +02:00
Steven Palma cb8edf17e6 chore(dependencies): update uv.lock (#3475) 2026-05-10 12:24:22 +02:00
Steven Palma 5699f6cbf4 chore(ci): disable auto-stale (#3550) 2026-05-10 11:49:31 +02:00
masato-ka 0e6114ac36 fix(train): restrict legacy RA-BC migration to JSON checkpoints only (#3490)
* fix(train): restrict legacy RA-BC migration to JSON checkpoints only

_migrate_legacy_rabc_fields was called for all config files, causing
json.load to raise DecodeError when a YAML/TOML config was passed to
lerobot-train for a new training run. Guard the block with an
.endswith(".json") check so migration only runs when resuming from
a JSON checkpoint.
2026-05-08 20:27:01 +02:00
Steven Palma c8ce413d73 fix(robots): allign lekiwi default with so100 use_degrees (#3531) 2026-05-07 17:52:34 +02:00
Pepijn 82dffde7fa fix(ci): speed up multi-task benchmark evals (parallelize + cap VLABench steps) (#3529)
* fix(ci): run multi-task benchmark evals 5-at-a-time in parallel

The eval script supports running tasks concurrently via a
ThreadPoolExecutor (env.max_parallel_tasks). Apply it to the four
multi-task benchmark CI jobs (RoboTwin, RoboCasa, RoboMME, LIBERO-plus
— 8-10 tasks/task_ids each) so they finish in ~2 waves of 5 instead of
running sequentially. Single-task jobs (Libero, MetaWorld, RoboCerebra)
are unchanged.

* fix(ci): cap VLABench smoke eval at 50 steps per task

VLABench's default episode_length is 500 steps; with 10 tasks at ~1 it/s
the smoke eval took ~80 minutes of rollouts on top of the image build.
The eval is a pipeline smoke test (running_success_rate stays at 0% on
this short rollout anyway), so we don't need full episodes — cap each
task at 50 steps to bring total rollout time down ~10x.

* fix(ci): run VLABench tasks 5-at-a-time in parallel

The eval script already supports running multiple tasks concurrently via
a ThreadPoolExecutor (env.max_parallel_tasks). Set it to 5 so the 10
VLABench tasks finish in ~2 waves instead of running sequentially.
2026-05-07 13:37:16 +02:00
Ville Kuosmanen eaf0218bc8 feat(policy): use pretrained vision encoder weights by default for diffusion and vqbet (#3202)
* feat: add pretrained vision encoder weights for diffusion and vqbet

* fix test by re-generating artifacts

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-07 12:10:38 +02:00
Pepijn a0e52d52fe fix(ci): bump robotwin benchmark image to CUDA 12.6 (#3525)
The robotwin benchmark Dockerfile still installed cuda-nvcc-12-4 and
cuda-cudart-dev-12-4 after #3505 upgraded the base image to CUDA 12.6.3
on Ubuntu 24.04. Those packages aren't available in the ubuntu2404 CUDA
repo, so the build failed at apt-get install. Bumping both to -12-6 to
match the base image.
2026-05-07 11:11:12 +02:00
90 changed files with 6119 additions and 3684 deletions
+6
View File
@@ -382,6 +382,7 @@ jobs:
--policy.path=\"\$ROBOTWIN_POLICY\" \
--env.type=robotwin \
--env.task=\"\$ROBOTWIN_TASKS\" \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -482,6 +483,7 @@ jobs:
--policy.path=lerobot/smolvla_robocasa \
--env.type=robocasa \
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -693,6 +695,7 @@ jobs:
--env.task=\"\$ROBOMME_TASKS\" \
--env.dataset_split=test \
--env.task_ids=[0] \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -800,6 +803,7 @@ jobs:
--env.type=libero_plus \
--env.task=\"\$LIBERO_PLUS_SUITE\" \
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
@@ -900,6 +904,8 @@ jobs:
--policy.path=lerobot/smolvla_vlabench \
--env.type=vlabench \
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
--env.episode_length=50 \
--env.max_parallel_tasks=5 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--eval.use_async_envs=false \
+2 -1
View File
@@ -152,13 +152,14 @@ jobs:
BASE_VERSION="${VERSION%%-*}"
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
uv pip install \
--torch-backend cpu \
--index-url https://test.pypi.org/simple/ \
--extra-index-url https://pypi.org/simple \
--index-strategy unsafe-best-match \
"lerobot[all]==$BASE_VERSION"
else
echo "Installing release version $VERSION from PyPI..."
uv pip install "lerobot[all]==$VERSION"
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
fi
- name: Check lerobot version
run: uv run python -c "import lerobot; print(lerobot.__version__)"
+2 -2
View File
@@ -19,8 +19,8 @@ on:
workflow_dispatch:
# Runs at 02:00
schedule:
- cron: "0 2 * * *"
# schedule:
# - cron: "0 2 * * *"
env:
CLOSE_ISSUE_MESSAGE: >
+2
View File
@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
All policies typically train for **510 epochs** (see §7).
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
+1 -1
View File
@@ -109,7 +109,7 @@ lerobot-train \
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
## Inference & Evaluation
-288
View File
@@ -1,288 +0,0 @@
# Video benchmark
## Questions
What is the optimal trade-off between:
- maximizing loading time with random access,
- minimizing memory space on disk,
- maximizing success rate of policies,
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
How to encode videos?
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
How to decode videos?
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
**Data augmentations**
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
### Encoding parameters
| parameter | values |
| ----------- | ------------------------------------------------------------ |
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
| **pix_fmt** | `yuv444p`, `yuv420p` |
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
### Decoding parameters
**Decoder**
We tested two video decoding backends from torchvision:
- `pyav`
- `video_reader` (requires to build torchvision from source)
**Requested timestamps**
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
- `1_frame`: 1 frame,
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
## Metrics
**Data compression ratio (lower is better)**
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
**Loading time ratio (lower is better)**
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
**Average Mean Square Error (lower is better)**
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
**Average Peak Signal to Noise Ratio (higher is better)**
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
<!-- **Loss of a pretrained policy (higher is better)** (not available)
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
**Success rate after retraining (higher is better)** (not available)
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
## How the benchmark works
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
This gives a unique set of encoding parameters which is used to encode the episode.
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
These are then all concatenated to a single table ready for analysis.
## Caveats
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
Additional encoding parameters exist that are not included in this benchmark. In particular:
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`
- `ffmpegio`
- `decord`
- `nvc`
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
## Install
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
## Adding a video decoder
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
You can easily add a new decoder to benchmark by adding it to this function in the script:
```diff
def decode_video_frames(
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
+ elif backend == ["your_decoder"]:
+ return your_decoder_function(
+ video_path, timestamps, tolerance_s, backend
+ )
else:
raise NotImplementedError(backend)
```
## Example
For a quick run, you can try these parameters:
```bash
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 2 20 None \
--crf 10 40 None \
--timestamps-modes 1_frame 2_frames \
--backends pyav video_reader \
--num-samples 5 \
--num-workers 5 \
--save-frames 0
```
## Results
### Reproduce
We ran the benchmark with the following parameters:
```bash
# h264 and h265 encodings
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libx264 libx265 \
--pix-fmt yuv444p yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
--crf 0 5 10 15 20 25 30 40 50 None \
--timestamps-modes 1_frame 2_frames 6_frames \
--backends pyav video_reader \
--num-samples 50 \
--num-workers 5 \
--save-frames 1
# av1 encoding (only compatible with yuv420p and pyav decoder)
python benchmark/video/run_video_benchmark.py \
--output-dir outputs/video_benchmark \
--repo-ids \
lerobot/pusht_image \
lerobot/aloha_mobile_shrimp_image \
lerobot/paris_street \
lerobot/kitchen \
--vcodec libsvtav1 \
--pix-fmt yuv420p \
--g 1 2 3 4 5 6 10 15 20 40 None \
--crf 0 5 10 15 20 25 30 40 50 None \
--timestamps-modes 1_frame 2_frames 6_frames \
--backends pyav \
--num-samples 50 \
--num-workers 5 \
--save-frames 1
```
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
### Parameters selected for LeRobotDataset
Considering these results, we chose what we think is the best set of encoding parameter:
- vcodec: `libsvtav1`
- pix-fmt: `yuv420p`
- g: `2`
- crf: `30`
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
### Summary
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
| video_images_size_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
| | libx264 | | libx265 | | libsvtav1 |
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
| | | vcodec | pix_fmt | | | |
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
| | | libx264 | | libx265 | | libsvtav1 |
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
-488
View File
@@ -1,488 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Assess the performance of video decoding in various configurations.
This script will benchmark different video encoding and decoding parameters.
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
"""
import argparse
import datetime as dt
import itertools
import random
import shutil
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import einops
import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames,
encode_video_frames,
)
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.utils import TimerManager
BASE_ENCODING = OrderedDict(
[
("vcodec", "libx264"),
("pix_fmt", "yuv444p"),
("g", 2),
("crf", None),
# TODO(aliberts): Add fastdecode
# ("fastdecode", 0),
]
)
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
def parse_int_or_none(value) -> int | None:
if value.lower() == "none":
return None
try:
return int(value)
except ValueError as e:
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
def check_datasets_formats(repo_ids: list) -> None:
for repo_id in repo_ids:
dataset = LeRobotDataset(repo_id)
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
)
def get_directory_size(directory: Path) -> int:
total_size = 0
for item in directory.rglob("*"):
if item.is_file():
total_size += item.stat().st_size
return total_size
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
frame = torch.from_numpy(np.array(frame))
frame = frame.type(torch.float32) / 255
frame = einops.rearrange(frame, "h w c -> c h w")
frames.append(frame)
return torch.stack(frames)
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
return
save_dir.mkdir(parents=True, exist_ok=True)
for i, ts in enumerate(timestamps):
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
return
imgs_dir.mkdir(parents=True, exist_ok=True)
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
if i >= ep_num_images - 1:
break
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
case "1_frame":
frame_indexes = [idx]
case "2_frames":
frame_indexes = [idx - 1, idx]
case "2_frames_4_space":
frame_indexes = [idx - 5, idx]
case "6_frames":
frame_indexes = [idx - i for i in range(6)][::-1]
case _:
raise ValueError(timestamps_mode)
return [idx / fps for idx in frame_indexes]
def benchmark_decoding(
imgs_dir: Path,
video_path: Path,
timestamps_mode: str,
backend: str,
ep_num_images: int,
fps: int,
num_samples: int = 50,
num_workers: int = 4,
save_frames: bool = False,
) -> dict:
def process_sample(sample: int, lock: Lock):
time_benchmark = TimerManager(log=False)
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
num_frames = len(timestamps)
result = {
"psnr_values": [],
"ssim_values": [],
"mse_values": [],
}
with time_benchmark, lock:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
with time_benchmark:
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
)
if save_frames and sample == 0:
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
return result
load_times_video_ms = []
load_times_images_ms = []
mse_values = []
psnr_values = []
ssim_values = []
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
# Use a single shared lock for all worker threads
shared_lock = Lock()
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
psnr_values.extend(result["psnr_values"])
ssim_values.extend(result["ssim_values"])
mse_values.extend(result["mse_values"])
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
return {
"avg_load_time_video_ms": avg_load_time_video_ms,
"avg_load_time_images_ms": avg_load_time_images_ms,
"video_images_load_time_ratio": video_images_load_time_ratio,
"avg_mse": float(np.mean(mse_values)),
"avg_psnr": float(np.mean(psnr_values)),
"avg_ssim": float(np.mean(ssim_values)),
}
def benchmark_encoding_decoding(
dataset: LeRobotDataset,
video_path: Path,
imgs_dir: Path,
encoding_cfg: dict,
decoding_cfg: dict,
num_samples: int,
num_workers: int,
save_frames: bool,
overwrite: bool = False,
seed: int = 1337,
) -> list[dict]:
fps = dataset.fps
if overwrite or not video_path.is_file():
tqdm.write(f"encoding {video_path}")
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=encoding_cfg["vcodec"],
pix_fmt=encoding_cfg["pix_fmt"],
g=encoding_cfg.get("g"),
crf=encoding_cfg.get("crf"),
# fast_decode=encoding_cfg.get("fastdecode"),
overwrite=True,
)
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)
video_images_size_ratio = video_size_bytes / images_size_bytes
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
timestamps_mode,
backend,
ep_num_images,
fps,
num_samples,
num_workers,
save_frames,
)
benchmark_row.update(
**{
"repo_id": dataset.repo_id,
"resolution": f"{width} x {height}",
"num_pixels": num_pixels,
"video_size_bytes": video_size_bytes,
"images_size_bytes": images_size_bytes,
"video_images_size_ratio": video_images_size_ratio,
"timestamps_mode": timestamps_mode,
"backend": backend,
},
**encoding_cfg,
)
benchmark_table.append(benchmark_row)
return benchmark_table
def main(
output_dir: Path,
repo_ids: list[str],
vcodec: list[str],
pix_fmt: list[str],
g: list[int],
crf: list[int],
# fastdecode: list[int],
timestamps_modes: list[str],
backends: list[str],
num_samples: int,
num_workers: int,
save_frames: bool,
):
check_datasets_formats(repo_ids)
encoding_benchmarks = {
"g": g,
"crf": crf,
# "fastdecode": fastdecode,
}
decoding_benchmarks = {
"timestamps_modes": timestamps_modes,
"backends": backends,
}
headers = ["repo_id", "resolution", "num_pixels"]
headers += list(BASE_ENCODING.keys())
headers += [
"timestamps_mode",
"backend",
"video_size_bytes",
"images_size_bytes",
"video_images_size_ratio",
"avg_load_time_video_ms",
"avg_load_time_images_ms",
"video_images_load_time_ratio",
"avg_mse",
"avg_psnr",
"avg_ssim",
]
file_paths = []
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
benchmark_table = []
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
dataset = LeRobotDataset(repo_id)
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for duet in [
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
for unique_combination in itertools.product(*encoding_benchmarks.values())
]:
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for key, value in duet.items():
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
# Save intermediate results
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
now = dt.datetime.now()
csv_path = (
output_dir
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
)
benchmark_df.to_csv(csv_path, header=True, index=False)
file_paths.append(csv_path)
del benchmark_df
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_df.to_csv(concatenated_path, header=True, index=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-dir",
type=Path,
default=Path("outputs/video_benchmark"),
help="Directory where the video benchmark outputs are written.",
)
parser.add_argument(
"--repo-ids",
type=str,
nargs="*",
default=[
"lerobot/pusht_image",
"lerobot/aloha_mobile_shrimp_image",
"lerobot/paris_street",
"lerobot/kitchen",
],
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
)
parser.add_argument(
"--vcodec",
type=str,
nargs="*",
default=["h264", "hevc", "libsvtav1"],
help="Video codecs to be tested",
)
parser.add_argument(
"--pix-fmt",
type=str,
nargs="*",
default=["yuv444p", "yuv420p"],
help="Pixel formats (chroma subsampling) to be tested",
)
parser.add_argument(
"--g",
type=parse_int_or_none,
nargs="*",
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
help="Group of pictures sizes to be tested.",
)
parser.add_argument(
"--crf",
type=parse_int_or_none,
nargs="*",
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
help="Constant rate factors to be tested.",
)
# parser.add_argument(
# "--fastdecode",
# type=int,
# nargs="*",
# default=[0, 1],
# help="Use the fastdecode tuning option. 0 disables it. "
# "For libx264 and libx265/hevc, only 1 is possible. "
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
# )
parser.add_argument(
"--timestamps-modes",
type=str,
nargs="*",
default=[
"1_frame",
"2_frames",
"2_frames_4_space",
"6_frames",
],
help="Timestamps scenarios to be tested.",
)
parser.add_argument(
"--backends",
type=str,
nargs="*",
default=["torchcodec", "pyav"],
help="Torchvision decoding backend to be tested.",
)
parser.add_argument(
"--num-samples",
type=int,
default=50,
help="Number of samples for each encoding x decoding config.",
)
parser.add_argument(
"--num-workers",
type=int,
default=10,
help="Number of processes for parallelized sample processing.",
)
parser.add_argument(
"--save-frames",
type=int,
default=0,
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
)
args = parser.parse_args()
main(**vars(args))
+1 -1
View File
@@ -35,7 +35,7 @@ USER root
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
libvulkan1 vulkan-tools \
&& mkdir -p /usr/share/vulkan/icd.d \
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
+1 -1
View File
@@ -18,7 +18,7 @@
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
# Configure the base image for CI with GPU access
ARG CUDA_VERSION=12.6.3
ARG CUDA_VERSION=12.8.1
ARG OS_VERSION=24.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
+7 -5
View File
@@ -8,7 +8,7 @@
- local: il_robots
title: Imitation Learning for Robots
- local: bring_your_own_policies
title: Bring Your Own Policies
title: Adding a Policy
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
@@ -24,6 +24,12 @@
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: hardware_guide
title: Compute Hardware Guide
- local: torch_accelerators
title: PyTorch accelerators
title: "Compute & Hardware"
- sections:
- local: lerobot-dataset-v3
title: Using LeRobotDataset
@@ -142,10 +148,6 @@
- local: cameras
title: Cameras
title: "Sensors"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
title: "Supported Hardware"
- sections:
- local: notebooks
title: Notebooks
+220 -81
View File
@@ -1,60 +1,37 @@
# Bring Your Own Policies
# Adding a Policy
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
## Step 1: Create a Policy Package
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
### Package Structure
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_custom_policy/
├── __init__.py
├── configuration_my_custom_policy.py
├── modeling_my_custom_policy.py
└── processor_my_custom_policy.py
```
---
### Package Configuration
## Anatomy of a policy
Set up your `pyproject.toml`:
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
### Configuration class
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
```python
# configuration_my_custom_policy.py
# configuration_my_policy.py
from dataclasses import dataclass, field
from lerobot.configs import PreTrainedConfig
from lerobot.optim import AdamWConfig
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("my_custom_policy")
@PreTrainedConfig.register_subclass("my_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
"""Configuration class for MyCustomPolicy.
class MyPolicyConfig(PreTrainedConfig):
"""Configuration class for MyPolicy.
Args:
n_obs_steps: Number of observation steps to use as input
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
raise ValueError("n_action_steps cannot exceed horizon")
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
"""Validate input/output feature compatibility.
Call this explicitly from your policy's __init__ — the base class does not.
"""
if not self.image_features:
raise ValueError("MyCustomPolicy requires at least one image feature.")
raise ValueError("MyPolicy requires at least one image feature.")
if self.action_feature is None:
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
raise ValueError("MyPolicy requires 'action' in output_features.")
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self):
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
return None
@property
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Relative timestep offsets for the action chunk the dataset loader returns.
"""
"""Relative timestep offsets for the action chunk the dataset loader returns."""
return list(range(self.horizon))
@property
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
return None
```
## Step 3: Implement the Policy Class
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
### Policy class
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
```python
# modeling_my_custom_policy.py
# modeling_my_policy.py
import torch
import torch.nn as nn
from typing import Any
from lerobot.policies import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .configuration_my_policy import MyPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
name = "my_custom_policy"
class MyPolicy(PreTrainedPolicy):
config_class = MyPolicyConfig # must match the string in @register_subclass
name = "my_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
super().__init__(config, dataset_stats)
config.validate_features() # not called automatically by the base class
self.config = config
self.model = ... # your nn.Module here
def reset(self):
"""Reset episode state."""
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
...
def get_optim_params(self) -> dict:
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
...
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
"""Return a single action for the current timestep (called at inference)."""
"""Return a single action for the current timestep (called every step at inference)."""
...
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
"""Compute the training loss.
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
logging-friendly Python natives (no tensors with gradients).
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
timesteps padded because the episode ended before `horizon` steps, you
timesteps padded because the episode ended before `horizon` steps; you
can exclude those from your loss.
"""
actions = batch[ACTION]
action_is_pad = batch.get("action_is_pad")
...
return {"loss": ...}
return loss, {"some_loss_component": some_loss_component.item()}
```
## Step 4: Add Data Processors
The methods called by the train/eval loops:
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
| Method | Used by | What it does |
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
### Processor functions
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
```python
# processor_my_custom_policy.py
# processor_my_policy.py
from typing import Any
import torch
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
def make_my_custom_policy_pre_post_processors(
def make_my_policy_pre_post_processors(
config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
return preprocessor, postprocessor
```
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
**Important function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
## Step 5: Package Initialization
---
Expose your classes in the package's `__init__.py`:
## Path A: Out-of-tree plugin
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
### Package structure
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
```bash
lerobot_policy_my_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_policy/
├── __init__.py
├── configuration_my_policy.py
├── modeling_my_policy.py
└── processor_my_policy.py
```
### `pyproject.toml`
```toml
[project]
name = "lerobot_policy_my_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.12"
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
### Package `__init__.py`
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
```python
# __init__.py
@@ -204,44 +239,148 @@ except ImportError:
"lerobot is not installed. Please install lerobot to use this policy package."
)
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .modeling_my_custom_policy import MyCustomPolicy
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
from .configuration_my_policy import MyPolicyConfig
from .modeling_my_policy import MyPolicy
from .processor_my_policy import make_my_policy_pre_post_processors
__all__ = [
"MyCustomPolicyConfig",
"MyCustomPolicy",
"make_my_custom_policy_pre_post_processors",
"MyPolicyConfig",
"MyPolicy",
"make_my_policy_pre_post_processors",
]
```
## Step 6: Installation and Usage
### Install Your Policy Package
### Install and use
```bash
cd lerobot_policy_my_custom_policy
cd lerobot_policy_my_policy
pip install -e .
# Or install from PyPI if published
pip install lerobot_policy_my_custom_policy
pip install lerobot_policy_my_policy
```
### Use Your Policy
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
```bash
lerobot-train \
--policy.type my_custom_policy \
--policy.type my_policy \
--env.type pusht \
--steps 200000
```
## Examples and Community Contributions
---
## Path B: Contributing in-tree
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
### In-tree layout
```
src/lerobot/policies/my_policy/
├── __init__.py # re-exports config + modeling + processor factory
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
├── processor_my_policy.py # make_my_policy_pre_post_processors
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
```
Two notes:
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
### Wiring
Three places need to know about your policy. All by name.
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
Mirror an existing policy that's structurally similar to yours; the diff is small.
### Heavy / optional dependencies
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
```python
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
else:
DDIMScheduler = None # keeps the symbol bindable at import time
class DiffusionPolicy(PreTrainedPolicy):
def __init__(self, config):
require_package("diffusers", extra="diffusion")
super().__init__(config)
...
```
This way:
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
- Type checkers see the real symbol.
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
### Benchmarks and a published checkpoint
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
```markdown
## Results
Evaluated on LIBERO with `lerobot/<policy>_libero`:
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 87.5% | 50 |
| libero_object | 93.0% | 50 |
| libero_goal | 81.5% | 50 |
| libero_10 | 62.0% | 50 |
| **average** | **81.0%** | 200 |
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
```
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
### PR checklist
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
---
## Examples and community contributions
Check out these example policy implementations:
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
Share your policy implementations with the community! 🤗
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
+98
View File
@@ -0,0 +1,98 @@
# Compute HW Guide for LeRobot Training
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
## Memory by policy group
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30100% over a forward+backward pass alone.
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
| Light BC | `act`, `vqbet`, `tdmpc` | ~26GB | Laptop GPU (RTX 3060), L4, A10G |
| Diffusion | `diffusion`, `multi_task_dit` | ~814GB | RTX 4070+ / L4 / A10G |
| Small VLA | `smolvla` | ~1016GB | RTX 4080+ / L4 / A10G |
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~2440GB | A100 40 GB+ (24 GB tight at BS 1) |
| Multimodal | `groot`, `eo1` | ~2440GB | A100 40 GB+ |
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
## Training time
Robotics imitation learning typically converges in **510 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
```text
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
total_steps = epochs × steps_per_epoch
wall_clock ≈ total_steps × per_step_time
```
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
### Common scenarios
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
| Setup | Policy | Batch | Wall-clock |
| ------------------------------------ | -------------- | ----- | ---------: |
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~3060min |
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~24h |
| Single L4 / A10G (24 GB) | `act` | 8 | ~12h |
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~36h |
| Single A100 40 GB | `smolvla` | 16 | ~12h |
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~48h |
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~3060min |
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~12h |
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~614h |
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
### Multi-GPU matters a lot
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
### Schedule and checkpoints
If you shorten training (e.g. 5k10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
## Where to run
VRAM is the first filter. Within a tier, pick by budget and availability — the `$``$$$$` columns are relative; check current pricing on the provider you actually use.
| Class | VRAM | Tier | Comfortable for |
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
| L4 / A10G (cloud) | 24 GB | `$$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
### Hugging Face Jobs
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
```bash
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
bash -c "nvidia-smi && lerobot-train \
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
```
Notes:
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
+40 -37
View File
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
### Understanding Configuration
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
<!-- prettier-ignore-start -->
```python
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
```
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
**Example Configuration**
```json
"end_effector_bounds": {
"max": [0.24, 0.20, 0.10],
"min": [0.16, -0.08, 0.03]
{
"env": {
"processor": {
"inverse_kinematics": {
"end_effector_bounds": {
"max": [0.24, 0.2, 0.1],
"min": [0.16, -0.08, 0.03]
}
}
}
}
}
```
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
<!-- prettier-ignore-start -->
```python
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
"""Configuration for the SO100FollowerEndEffector robot."""
class InverseKinematicsConfig:
"""Configuration for inverse kinematics processing."""
# Default bounds for the end-effector position (in meters)
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
default_factory=lambda: {
"min": [-1.0, -1.0, -1.0], # min x, y, z
"max": [1.0, 1.0, 1.0], # max x, y, z
}
)
urdf_path: str | None = None
target_frame_name: str | None = None
# bounds for the end-effector in x,y,z direction
end_effector_bounds: dict[str, list[float]] | None = None
# maximum step size for the end-effector in x,y,z direction
end_effector_step_sizes: dict[str, float] | None = None
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
default_factory=lambda: {
"x": 0.02,
"y": 0.02,
"z": 0.02,
}
)
class HILSerlProcessorConfig:
...
# maximum gripper position that the gripper will be open at
max_gripper_pos: float | None = 100.0
```
<!-- prettier-ignore-end -->
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
**Collecting a Dataset for the reward classifier**
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
@@ -658,7 +661,7 @@ Example configuration section for data collection:
},
"dataset": {
"repo_id": "hf_username/dataset_name",
"dataset_root": "data/your_dataset",
"root": "data/your_dataset",
"task": "reward_classifier_task",
"num_episodes_to_record": 20,
"replay_episode": null,
@@ -671,7 +674,7 @@ Example configuration section for data collection:
**Reward Classifier Configuration**
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
- **model_type**: `"cnn"` or `"transformer"`
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
"repo_id": "hf_username/dataset_name",
"root": null
},
"policy": {
"reward_model": {
"type": "reward_classifier",
"model_name": "helper2424/resnet10",
"model_type": "cnn",
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
"dropout_rate": 0.1,
"learning_rate": 1e-4,
"device": "cuda",
"use_amp": true,
"input_features": {
"observation.images.front": {
"type": "VISUAL",
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
**Configuration Setup**
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
1. Configure the policy settings (`type="sac"`, `device`, etc.)
2. Set `dataset` to your cropped dataset
3. Configure environment settings with crop parameters
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
3. Set `dataset` to your cropped dataset
4. Configure environment settings with crop parameters
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
**Starting the Learner**
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
Some configuration values have a disproportionate impact on training stability and speed:
- **`temperature_init`** (`policy.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`temperature_init`** (`algorithm.temperature_init`) initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
- **`storage_device`** (`policy.storage_device`) device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
+50
View File
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
### PyTorch CUDA variant (Linux only)
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
<!-- prettier-ignore-start -->
<hfoptions id="cuda_variant">
<hfoption id="uv-source">
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
To override for a different CUDA variant:
```bash
uv pip install --force-reinstall torch torchvision \
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
```
</hfoption>
<hfoption id="pip-conda">
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
To pick a specific CUDA variant:
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
```bash
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
pip install -e ".[all]" # source
# — or —
pip install lerobot # from PyPI
```
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
```bash
uv pip install --torch-backend cu128 lerobot
```
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
</hfoption>
</hfoptions>
<!-- prettier-ignore-end -->
### Troubleshooting
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
+136
View File
@@ -0,0 +1,136 @@
# OMX Follower — Cube Pick And Place Example
This is an example of what is possible to do with LeRobot on a physical setup.
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
## Hardware
| Component | Value |
| --------- | ------------------------------------ |
| Robot | OMX Follower |
| Cameras | 2× OpenCV cameras (wrist + top-down) |
## Scripts
| Script | Purpose |
| ---------------------- | --------------------------------------------------------------- |
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
## Setup
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
```bash
export ROBOT_PORT=/dev/ttyACM0
export TELEOP_PORT=/dev/ttyACM1
export HF_USERNAME=<your_hf_username>
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
```
## Step 1 — Collect Data
```bash
lerobot-record \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--teleop.type=omx_leader \
--teleop.port=$TELEOP_PORT \
--teleop.id=omx_leader \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--dataset.root=data/omx_pickandplace \
--dataset.num_episodes=50 \
--dataset.single_task="Pick the cube and place it in the blue square" \
--dataset.streaming_encoding=true \
--dataset.push_to_hub=true
```
### Bonus Auto-Collect script
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
```bash
python -m examples.omx.record_grab \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--dataset.root=data/omx_pickandplace \
--dataset.num_episodes=50 \
--dataset.single_task="Pick the cube and place it in the blue square" \
--dataset.streaming_encoding=true \
--dataset.push_to_hub=true
```
Each episode:
1. The arm grabs the cube from the center of the workspace and places it at a random position.
2. The arm returns to HOME.
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
## Step 2 — Train
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
```bash
lerobot-train \
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
--policy.type=act \
--output_dir=outputs/train/omx_pickandplace_act \
--policy.device=cuda \
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
--steps=20000 \
--wandb.enable=true
```
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
## Step 3 — Rollout
Use the `lerobot-rollout` CLI with base strategy:
```bash
lerobot-rollout \
--strategy.type=base \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--policy.path=$HF_USERNAME/omx_pickandplace_act \
```
For continuous recording with automatic upload (sentry mode):
```bash
lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=10 \
--robot.type=omx_follower \
--robot.port=$ROBOT_PORT \
--robot.id=omx_follower \
--robot.cameras="$ROBOT_CAMERAS" \
--policy.path=$HF_USERNAME/omx_pickandplace_act \
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
```
## Environment Reset Utility
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
`reset_environment.py` can be run standalone to prepare the workspace:
```bash
# Grab cube + place it at a random position on the left side
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
```
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
+422
View File
@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""
Auto-record grab episodes for the OMX robot arm.
Each episode cycle:
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
2. HOME — return arm to home with gripper open
3. record_grab — execute a targeted grab to the stored position while recording
observations + actions to a LeRobotDataset
Usage (run from repo root):
python -m examples.omx.record_grab \\
--robot.type=omx_follower \\
--robot.port=/dev/ttyACM0 \\
--robot.id=omx_follower \\
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
--dataset.repo_id=<hf_username>/<dataset_name> \\
--dataset.root=data/omx_grab \\
--dataset.num_episodes=50 \\
--dataset.single_task="Grab the cube" \\
--dataset.streaming_encoding=true
"""
import logging
from dataclasses import dataclass
from pprint import pformat
import numpy as np
from lerobot.cameras import CameraConfig # noqa: F401
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
LeRobotDataset,
VideoEncodingManager,
aggregate_pipeline_dataset_features,
create_initial_features,
)
from lerobot.processor import make_default_processors
from lerobot.robots import RobotConfig, make_robot_from_config
from lerobot.robots.omx_follower import OmxFollower
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.robot_utils import precise_sleep
from .reset_environment import (
APPROACH_SPEED,
GRIPPER_CLOSE_POS,
HOME_POSE,
PUSH_END_ELBOW_FLEX,
PUSH_END_SHOULDER_LIFT,
PUSH_START_ELBOW_FLEX,
PUSH_START_SHOULDER_LIFT,
array_to_pose,
grab_cube,
horizontal_wrist_flex,
move_to_pose,
place_cube,
pose_to_array,
)
# ── Grab-episode motion parameters ────────────────────────────────────────────
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
GRAB_RAISE_SL_OFFSET = 20.0
GRAB_LOWER_SPEED = 20.0
RECORD_SPEED = 30.0
# Pose the arm travels to after closing the gripper (cube held).
GRAB_CARRY_POSE = {
"shoulder_pan.pos": -23.0,
"shoulder_lift.pos": 5.0,
"elbow_flex.pos": 18.0,
"wrist_flex.pos": -14.0,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
}
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
# Cube-approach and carry poses are never jittered to preserve precision.
_JITTER_LIMITS: dict[str, float] = {
"shoulder_pan.pos": 5.0,
"shoulder_lift.pos": 4.0,
"elbow_flex.pos": 4.0,
"wrist_flex.pos": 3.0,
"wrist_roll.pos": 2.0,
"gripper.pos": 0.0,
}
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
"""Return a copy of pose with independent per-joint random perturbations."""
return {
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
}
def _random_stuck_pose(rng: np.random.Generator) -> dict:
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
table-safe envelope across the full sl range:
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
sl= 0 → ef ∈ [-25, 25] (mid reach)
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
wrist_flex is randomly offset from the horizontal value.
"""
pan = float(rng.uniform(-5.0, 35.0))
sl = float(rng.uniform(-50.0, 30.0))
if sl <= 0.0:
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
ef_lo = alpha * -25.0 # 0 → -25
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
else:
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
ef = float(rng.uniform(ef_lo, ef_hi))
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf,
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
"gripper.pos": GRIPPER_CLOSE_POS,
}
logger = logging.getLogger(__name__)
@dataclass
class OmxRecordGrabConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Resume recording on an existing dataset.
resume: bool = False
# Fraction of episodes that start from a random stuck pose (gripper closed) to
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
recovery_prob: float = 0.5
def record_episode_spline(
robot: OmxFollower,
waypoints: list[dict],
speeds: list[float],
dataset: LeRobotDataset,
task: str,
) -> None:
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
Segment durations are parameterized from the maximum absolute joint delta
between consecutive waypoints divided by the requested segment speed,
producing non-uniform timing in joint space. Interior tangents are derived
from the adjacent per-segment velocities, with clamped (zero-velocity)
endpoints so the arm starts and stops smoothly. Each segment is cubic
Hermite, giving C1 continuity at every waypoint.
"""
pts = [pose_to_array(w) for w in waypoints]
n = len(pts)
# Steps and duration per segment
n_steps_list = []
timestamps = []
for i in range(n - 1):
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
n_steps_list.append(ns)
timestamps.append(ns / dataset.fps)
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
vels = [np.zeros_like(pts[0])]
for i in range(1, n - 1):
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
vels.append(0.5 * (v_prev + v_next))
vels.append(np.zeros_like(pts[0]))
dt = 1.0 / dataset.fps
for seg in range(n - 1):
ns = n_steps_list[seg]
if ns == 0:
continue
p0, p1 = pts[seg], pts[seg + 1]
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
m0 = vels[seg] * timestamps[seg]
m1 = vels[seg + 1] * timestamps[seg]
for step in range(1, ns + 1):
t = step / ns
h00 = 2 * t**3 - 3 * t**2 + 1
h10 = t**3 - 2 * t**2 + t
h01 = -2 * t**3 + 3 * t**2
h11 = t**3 - t**2
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
action = array_to_pose(commanded)
robot.send_action(action)
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
dataset.add_frame({**obs_frame, **action_frame, "task": task})
precise_sleep(dt)
def record_grab_episode(
robot: OmxFollower,
dataset: LeRobotDataset,
pan: float,
t: float,
task: str,
recovery_start: bool = False,
) -> None:
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
Normal sequence (initial HOME move is NOT recorded):
HOME → raised approach above cube → lower → close gripper
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
(gripper closed) without recording, then recording begins from there:
stuck_pose → raised approach above cube → [normal grab sequence from there]
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
"""
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
sl_raised = sl - GRAB_RAISE_SL_OFFSET
wf_horizontal = horizontal_wrist_flex(sl, ef)
rng = np.random.default_rng()
if recovery_start:
stuck_pose = _random_stuck_pose(rng)
logger.info(f"Recovery start: {stuck_pose}")
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
first_waypoints = [stuck_pose]
first_speeds = []
else:
jittery_start = _jitter_pose(HOME_POSE, rng)
move_to_pose(robot, jittery_start, APPROACH_SPEED)
first_waypoints = [jittery_start]
first_speeds = []
waypoints = first_waypoints + [
{ # raised approach: arm above cube
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl_raised,
"elbow_flex.pos": ef,
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{ # lower onto cube — no jitter: precision needed
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf_horizontal,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{ # close gripper — no jitter: precision needed
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": ef,
"wrist_flex.pos": wf_horizontal,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
_jitter_pose(
{ # raise with cube
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl_raised,
"elbow_flex.pos": ef,
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
rng,
),
_jitter_pose(
{ # retract: fold arm toward HOME before sweeping to carry zone
"shoulder_pan.pos": pan * 0.25,
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
"wrist_flex.pos": 0.0,
"wrist_roll.pos": 0.0,
"gripper.pos": GRIPPER_CLOSE_POS,
},
rng,
),
GRAB_CARRY_POSE, # no jitter: target drop zone
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
HOME_POSE,
]
speeds = first_speeds + [
RECORD_SPEED, # (HOME →) raised approach
GRAB_LOWER_SPEED, # raised approach → lower
GRAB_LOWER_SPEED, # lower → close gripper
RECORD_SPEED, # close gripper → raise
RECORD_SPEED, # raise → retract
RECORD_SPEED, # retract → carry pose
RECORD_SPEED, # carry pose → drop
RECORD_SPEED, # drop → HOME
]
record_episode_spline(robot, waypoints, speeds, dataset, task)
# Dwell at HOME for ~0.5 s before next episode
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
dt = 1.0 / dataset.fps
for _ in range(int(dataset.fps * 0.5)):
robot.send_action(HOME_POSE)
obs = robot.get_observation()
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
dataset.add_frame({**obs_frame, **home_action, "task": task})
precise_sleep(dt)
@parser.wrap()
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger.info(pformat(cfg))
robot = make_robot_from_config(cfg.robot)
use_videos = cfg.dataset.video
teleop_action_processor, _, robot_obs_processor = make_default_processors()
dataset_features = combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=teleop_action_processor,
initial_features=create_initial_features(action=robot.action_features),
use_videos=use_videos,
),
aggregate_pipeline_dataset_features(
pipeline=robot_obs_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=use_videos,
),
)
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
dataset = None
try:
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
root=cfg.dataset.root,
streaming_encoding=cfg.dataset.streaming_encoding,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
else:
cfg.dataset.stamp_repo_id()
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=use_videos,
streaming_encoding=cfg.dataset.streaming_encoding,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
vcodec=cfg.dataset.vcodec,
encoder_threads=cfg.dataset.encoder_threads,
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
if num_cameras > 0
else 0,
)
robot.connect(calibrate=True)
rng = np.random.default_rng()
with VideoEncodingManager(dataset):
for episode_idx in range(cfg.dataset.num_episodes):
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
logger.info("Step 1: grabbing and placing cube...")
grab_cube(robot)
pan, t = place_cube(robot)
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
record_grab_episode(
robot,
dataset,
pan,
t,
cfg.dataset.single_task,
recovery_start=recovery_start,
)
dataset.save_episode()
logger.info(f"Episode {episode_idx + 1} saved.")
finally:
if dataset:
dataset.finalize()
if robot.is_connected:
robot.disconnect()
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
return dataset
if __name__ == "__main__":
record_grab()
+267
View File
@@ -0,0 +1,267 @@
#!/usr/bin/env python3
"""
Auto-reset and cube-grab utility for the OMX robot arm.
Provides:
- grab_cube(robot): sweep workspace, center cube, close gripper
- place_cube(robot): carry cube to a random position, release
Standalone usage (run from repo root):
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
To read current joint values for calibration, add after robot.connect():
obs = robot.get_observation()
print({k: round(obs[k], 1) for k in JOINT_NAMES})
robot.disconnect(); raise SystemExit
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
Linear interpolation preserves this constraint between any two poses that satisfy it.
"""
import argparse
import logging
import numpy as np
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
from lerobot.robots.robot import Robot
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
# ── Poses ─────────────────────────────────────────────────────────────────────
HOME_POSE = {
"shoulder_pan.pos": 0.0,
"shoulder_lift.pos": -50.0,
"elbow_flex.pos": 50.0,
"wrist_flex.pos": 0.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
}
SWEEP_WAYPOINTS = [
{
"shoulder_pan.pos": -60.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -60.0,
"wrist_flex.pos": -20.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{
"shoulder_pan.pos": -30.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -60.0,
"wrist_flex.pos": -5.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
{
"shoulder_pan.pos": 20.0,
"shoulder_lift.pos": 50.0,
"elbow_flex.pos": -55.0,
"wrist_flex.pos": -5.0,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
},
]
# ── Motion parameters ─────────────────────────────────────────────────────────
CONTROL_HZ = 30
APPROACH_SPEED = 50.0
SWEEP_SPEED = 40.0
# ── Grab-sequence parameters ──────────────────────────────────────────────────
GRAB_PAN = 0.0
SWEEP_LEFT_PAN = -60.0
SWEEP_RIGHT_PAN = 60.0
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
SWEEP_END_PAN_RANGE = (15.0, 20.0)
SWEEP_LOW_SHOULDER_LIFT = 50.0
SWEEP_LOW_ELBOW_FLEX_START = -60.0
SWEEP_LOW_ELBOW_FLEX_END = -55.0
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
PUSH_START_SHOULDER_LIFT = 0.0
PUSH_START_ELBOW_FLEX = 45.0
PUSH_END_SHOULDER_LIFT = 50.0
PUSH_END_ELBOW_FLEX = -50.0
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
# Does not affect the grab-target interpolation in record_grab.py.
PUSH_RAISE_OFFSET = 5.0
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
GRIPPER_CLOSE_POS = 50.0
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
JOINT_NAMES = [
"shoulder_pan.pos",
"shoulder_lift.pos",
"elbow_flex.pos",
"wrist_flex.pos",
"wrist_roll.pos",
"gripper.pos",
]
# ── Helpers ───────────────────────────────────────────────────────────────────
def pose_to_array(pose: dict) -> np.ndarray:
return np.array([pose[k] for k in JOINT_NAMES])
def array_to_pose(arr: np.ndarray) -> dict:
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
sl = SWEEP_LOW_SHOULDER_LIFT
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": sl,
"elbow_flex.pos": elbow_flex,
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
"wrist_roll.pos": 0.0,
"gripper.pos": 60.0,
}
def _high_sweep_pose(pan: float) -> dict:
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
return {
"shoulder_pan.pos": pan,
"shoulder_lift.pos": shoulder_lift,
"elbow_flex.pos": elbow_flex,
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
"wrist_roll.pos": 0.0,
"gripper.pos": gripper,
}
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
"""Interpolate from current position to target at the given speed (units/s)."""
obs = robot.get_observation()
current = np.array([obs[k] for k in JOINT_NAMES])
goal = pose_to_array(target)
max_distance = float(np.max(np.abs(goal - current)))
if max_distance < 0.5:
return
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
dt = 1.0 / CONTROL_HZ
for step in range(1, n_steps + 1):
t = step / n_steps
robot.send_action(array_to_pose(current + t * (goal - current)))
precise_sleep(dt)
# ── Sequences ─────────────────────────────────────────────────────────────────
def grab_cube(robot: Robot) -> None:
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
for pan, end_pan in [
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
]:
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
move_to_pose(
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
)
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
logger.info("Extending to push cube into gripper...")
move_to_pose(
robot,
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
APPROACH_SPEED,
)
move_to_pose(
robot,
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
SWEEP_SPEED,
)
logger.info("Closing gripper...")
move_to_pose(
robot,
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
APPROACH_SPEED,
)
logger.info("Grab complete.")
def place_cube(robot: Robot) -> tuple[float, float]:
"""Carry the cube (gripper closed) to a random position on the left side, then release.
Returns:
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
"""
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
t = float(np.random.uniform(*PLACE_REACH_RANGE))
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
move_to_pose(
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
)
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
logger.info("Place complete.")
return pan, t
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
parser.add_argument("--port", default="/dev/ttyACM1")
parser.add_argument("--robot_id", default="omx_follower")
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
robot.connect(calibrate=True)
try:
if args.mode == "grab":
grab_cube(robot)
elif args.mode == "grab_and_place":
grab_cube(robot)
place_cube(robot)
finally:
robot.disconnect()
if __name__ == "__main__":
main()
+24 -21
View File
@@ -4,13 +4,13 @@ from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rewards.classifier.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -28,7 +28,7 @@ def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
policy_learner: GaussianActorPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
@@ -40,8 +40,9 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers_and_scheduler()
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -83,24 +84,26 @@ def run_learner(
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
def batch_iter(b=batch):
while True:
yield b
optimizer.zero_grad()
loss.backward()
optimizer.step()
stats = algorithm.update(batch_iter())
training_step += 1
if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
@@ -113,7 +116,7 @@ def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
policy_actor: GaussianActorPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
@@ -144,15 +147,15 @@ def run_actor(
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy
# Get action from policy (returns full action: continuous + discrete)
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action_tensor = policy_actor.select_action(policy_obs)
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
@@ -261,14 +264,14 @@ def main():
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
policy_cfg = GaussianActorConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
policy_actor = GaussianActorPolicy(policy_cfg)
policy_learner = GaussianActorPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
+29 -4
View File
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
dependencies = [
# Core ML
"torch>=2.7,<2.11.0",
"torchvision>=0.22.0,<0.26.0",
"torch>=2.7,<2.12.0",
"torchvision>=0.22.0,<0.27.0",
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"opencv-python-headless>=4.9.0,<4.14.0",
"Pillow>=10.0.0,<13.0.0",
@@ -99,7 +99,18 @@ dataset = [
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
# NOTE: torchcodec wheel availability matrix (PyPI):
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
"jsonlines>=4.0.0,<5.0.0",
]
training = [
@@ -195,7 +206,7 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -293,6 +304,20 @@ lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# ---------------- Tool Configurations ----------------
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
# uv pip install --force-reinstall torch torchvision \
# --index-url https://download.pytorch.org/whl/cu130
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
+1
View File
@@ -99,6 +99,7 @@ def save_checkpoint(
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
+3 -8
View File
@@ -256,7 +256,9 @@ class TrainPipelineConfig(HubMixin):
) from e
cli_args = kwargs.pop("cli_args", [])
if config_file is not None:
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
if config_file is not None and config_file.endswith(".json"):
with open(config_file) as f:
config = json.load(f)
migrated_config = _migrate_legacy_rabc_fields(config)
@@ -267,10 +269,3 @@ class TrainPipelineConfig(HubMixin):
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args)
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
+24
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Callable
from pathlib import Path
import numpy as np
@@ -189,6 +190,29 @@ class LeRobotDatasetMetadata:
if self.episodes is None:
self._load_metadata()
def filter_episodes(
self,
predicate: Callable[[dict], bool],
candidates: list[int] | None = None,
) -> list[int]:
"""Filter episodes whose metadata satisfies a given predicate.
Args:
predicate: Predicate over per-episode metadata rows used to select episodes.
candidates: Optional list of episode indices to restrict evaluation to.
Returns:
List of sorted episode indices that satisfy the predicate.
"""
self.ensure_readable()
if candidates is not None:
candidate_set = set(candidates)
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
else:
combined = predicate
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
return sorted(int(idx) for idx in filtered["episode_index"])
def _pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
+23 -1
View File
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
episode_filter: Callable[[dict], bool] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
``$HF_LEROBOT_HOME/hub``.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list. Defaults to None.
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
Defaults to None.
image_transforms (Callable | None, optional):
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
conversion. This works for both image-backed and video-backed observations and can later be
@@ -199,7 +205,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self._video_backend = video_backend if video_backend else get_safe_default_codec()
@@ -218,6 +223,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root = self.meta.root
self.revision = self.meta.revision
if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
):
logger.warning(
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
)
if episode_filter is not None:
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
if not resolved:
raise ValueError(
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
)
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
episodes = resolved
self.episodes = episodes
# Create reader (hf_dataset loaded below)
self.reader = DatasetReader(
meta=self.meta,
+63 -56
View File
@@ -33,7 +33,6 @@ import fsspec
import numpy as np
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
@@ -132,7 +131,9 @@ def decode_video_frames(
video_path (Path): Path to the video file.
timestamps (list[float]): List of timestamps to extract frames.
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
accepted for one release as an alias for "pyav" and will be removed in a future version.
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
@@ -145,85 +146,87 @@ def decode_video_frames(
backend = get_safe_default_codec()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
)
elif backend == "pyav":
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
else:
raise ValueError(f"Unsupported video backend: {backend}")
def decode_video_frames_torchvision(
def decode_video_frames_pyav(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
log_loaded_timestamps: bool = False,
return_uint8: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video
"""Loads frames associated to the requested timestamps of a video using PyAV.
The backend can be either "pyav" (default) or "video_reader".
"video_reader" requires installing torchvision from source, see:
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
(note that you need to compile against ffmpeg<4.3)
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
x86_64 and linux armv7l see the torchcodec block in pyproject.toml for the full matrix).
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
accurate seek.
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
For more info on video decoding, see `benchmark/video/README.md`
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
forward until we have covered the requested timestamp range. The number of key frames in a
video can be adjusted at encoding time to trade off decoding speed against file size.
See torchvision doc for more info on these two backends:
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
Args:
video_path: Path to the video file.
timestamps: List of timestamps (in seconds) to extract frames for.
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
[0, 1] range.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
Returns:
torch.Tensor of shape (len(timestamps), C, H, W).
"""
video_path = str(video_path)
# set backend
keyframes_only = False
torchvision.set_video_backend(backend)
if backend == "pyav":
keyframes_only = True # pyav doesn't support accurate seek
# set a video stream reader
# TODO(rcadene): also load audio stream at the same time
reader = torchvision.io.VideoReader(video_path, "video")
video_path = str(video_path)
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = min(timestamps)
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts, keyframes_only=keyframes_only)
loaded_frames: list[torch.Tensor] = []
loaded_ts: list[float] = []
# load all frames until last requested frame
loaded_frames = []
loaded_ts = []
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
with av.open(video_path) as container:
stream = container.streams.video[0]
container.seek(int(first_ts * av.time_base), backward=True)
if backend == "pyav":
reader.container.close()
for frame in container.decode(stream):
if frame.pts is None:
continue
current_ts = float(frame.pts * stream.time_base)
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
# Convert to CHW uint8 to match torchcodec's output layout.
arr = frame.to_ndarray(format="rgb24") # H, W, 3
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
reader = None
if not loaded_frames:
raise FrameTimestampError(
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
)
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
loaded_ts_t = torch.tensor(loaded_ts)
# compute distances between each query timestamp and timestamps of all loaded frames
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
@@ -234,14 +237,14 @@ def decode_video_frames_torchvision(
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nloaded timestamps: {loaded_ts_t}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
f"\nbackend: pyav"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
closest_ts = loaded_ts_t[argmin_]
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
@@ -282,7 +285,11 @@ class VideoDecoderCache:
with self._lock:
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
decoder = VideoDecoder(file_handle, seek_mode="approximate")
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
file_handle.close()
raise
self._cache[video_path] = (decoder, file_handle)
return self._cache[video_path][0]
+5 -5
View File
@@ -18,13 +18,13 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .sac.configuration_sac import SACConfig as SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
@@ -32,21 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
__all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"MultiTaskDiTConfig",
"EO1Config",
"PI0Config",
"PI0FastConfig",
"PI05Config",
"SACConfig",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
horizon: int = 64
n_action_steps: int = 32
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
use_group_norm: bool = False
spatial_softmax_num_keypoints: int = 32
use_separate_rgb_encoder_per_camera: bool = False
use_separate_rgb_encoder_per_camera: bool = True
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
+11 -11
View File
@@ -47,12 +47,12 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
@@ -88,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -127,10 +127,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from .sac.modeling_sac import SACPolicy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
return SACPolicy
return GaussianActorPolicy
elif name == "smolvla":
from .smolvla.modeling_smolvla import SmolVLAPolicy
@@ -167,7 +167,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -191,8 +191,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "groot":
@@ -365,10 +365,10 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from .sac.processor_sac import make_sac_pre_post_processors
elif isinstance(policy_cfg, GaussianActorConfig):
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
processors = make_sac_pre_post_processors(
processors = make_gaussian_actor_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACConfig
from .modeling_sac import SACPolicy
from .processor_sac import make_sac_pre_post_processors
from .configuration_gaussian_actor import GaussianActorConfig
from .modeling_gaussian_actor import GaussianActorPolicy
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
@@ -75,18 +75,19 @@ class PolicyConfig:
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@PreTrainedConfig.register_subclass("gaussian_actor")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
class GaussianActorConfig(PreTrainedConfig):
"""Gaussian actor configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configures the policy-side (actor + observation encoder) of a Gaussian
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
By default the actor output is a tanh-squashed diagonal Gaussian
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
CLI: ``--policy.type=gaussian_actor``.
"""
# Mapping of feature types to normalization modes
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu"
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8
# Training parameter
# Encoder architecture
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
# Number of steps for online training
online_steps: int = 1000000
# Capacity of the online replay buffer
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1
# SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations
use_torch_compile: bool = True
# Network architecture
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters (Gaussian head)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
# Any validation specific to GaussianActor configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
# this preset.
default_lr = 3e-4
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
"actor": {"lr": default_lr},
"critic": {"lr": default_lr},
"temperature": {"lr": default_lr},
},
)
@@ -15,16 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import asdict
from typing import Literal
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_sac import SACConfig, is_image_feature
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
class GaussianActorPolicy(
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
config_class = GaussianActorConfig
name = "gaussian_actor"
def __init__(
self,
config: SACConfig | None = None,
config: GaussianActorConfig | None = None,
):
super().__init__(config)
config.validate_features()
@@ -54,9 +49,8 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_temperature()
self._init_discrete_critic()
def get_optim_params(self) -> dict:
optim_params = {
@@ -65,11 +59,7 @@ class SACPolicy(
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
return optim_params
def reset(self):
@@ -79,7 +69,9 @@ class SACPolicy(
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
raise NotImplementedError(
"GaussianActorPolicy does not support action chunking. It returns single actions!"
)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -92,360 +84,43 @@ class SACPolicy(
actions, _, _ = self.actor(batch, observations_features)
if self.config.num_discrete_actions is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
if self.discrete_critic is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
else:
discrete_action = torch.ones(
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
"""Actor forward pass: sample actions and return log-probabilities.
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
batch: A flat observation dict, or a training dict containing
``"state"`` (observations) and optionally ``"observation_feature"``
(pre-computed encoder features).
Returns:
Tensor of Q-values from all critics
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
observations = batch.get("state", batch)
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_critic = GaussianActorObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
"""Initialize policy actor network."""
# NOTE: The actor select only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
@@ -455,21 +130,25 @@ class SACPolicy(
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_discrete_critic(self) -> None:
"""Initialize discrete critic network."""
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
# TODO(Khalil): Compile the discrete critic
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
class SACObservationEncoder(nn.Module):
class GaussianActorObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig) -> None:
def __init__(self, config: GaussianActorConfig) -> None:
super().__init__()
self.config = config
self._init_image_layers()
@@ -677,84 +356,6 @@ class MLP(nn.Module):
return self.net(x)
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: float | None = None,
init_final: float | None = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: SACObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
class DiscreteCritic(nn.Module):
def __init__(
self,
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
class Policy(nn.Module):
def __init__(
self,
encoder: SACObservationEncoder,
encoder: GaussianActorObservationEncoder,
network: nn.Module,
action_dim: int,
std_min: float = -5,
@@ -811,7 +412,7 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder: SACObservationEncoder = encoder
self.encoder: GaussianActorObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.std_min = std_min
@@ -885,7 +486,7 @@ class Policy(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
image_key = next(key for key in config.input_features if is_image_feature(key))
self.image_enc_layers = nn.Sequential(
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
def _load_pretrained_vision_encoder(self, config: SACConfig):
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
@@ -32,18 +32,18 @@ from lerobot.processor import (
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from .configuration_sac import SACConfig
from .configuration_gaussian_actor import GaussianActorConfig
def make_sac_pre_post_processors(
config: SACConfig,
def make_gaussian_actor_pre_post_processors(
config: GaussianActorConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the SAC policy.
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the SAC policy.
config: The configuration object for the tanh-Gaussian policy.
dataset_stats: A dictionary of statistics for normalization.
Returns:
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
pretrained_backbone_weights: str | None = None
use_group_norm: bool = True
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
use_group_norm: bool = False
spatial_softmax_num_keypoints: int = 32
# VQ-VAE
n_vqvae_training_steps: int = 20000
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
target_dtype = torch.get_autocast_dtype(query_states.device.type)
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
+31 -11
View File
@@ -4,7 +4,6 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
@@ -321,6 +320,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
This step normalizes the `transition` object by:
1. Copying `teleop_action` from `info` to `complementary_data`.
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
3. Copying `discrete_penalty` from `info` to `complementary_data`.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -330,6 +330,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
if TELEOP_ACTION_KEY in info:
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
if DISCRETE_PENALTY_KEY in info:
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
if "is_intervention" in info:
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
@@ -348,18 +351,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
Applies a small per-transition cost on the discrete gripper action.
This step penalizes actions that attempt to close an already closed gripper or
open an already open one, based on position thresholds.
Fires only when the commanded action would actually transition the gripper
from one extreme to the other (close-while-open or open-while-closed).
This discourages gripper oscillation while leaving "stay" and saturating-further
commands unpenalized.
Attributes:
penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization.
open_threshold: Normalized state below which the gripper is considered "open".
closed_threshold: Normalized state above which the gripper is considered "closed".
"""
penalty: float = -0.01
penalty: float = -0.02
max_gripper_pos: float = 30.0
open_threshold: float = 0.1
closed_threshold: float = 0.9
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
@@ -379,11 +388,15 @@ class GripperPenaltyProcessorStep(ProcessorStep):
if raw_joint_positions is None:
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
current_gripper_pos = raw_joint_positions.get(f"{GRIPPER_KEY}.pos", None)
if current_gripper_pos is None:
return new_transition
# Gripper action is a PolicyAction at this stage
# During reset, the transition may not carry any action yet.
if action is None:
return new_transition
# Gripper action is expected as the last action dimension.
gripper_action = action[-1].item()
gripper_action_normalized = gripper_action / self.max_gripper_pos
@@ -391,9 +404,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
# Calculate penalty boolean as in original
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
)
# - currently open AND target is closed -> close transition
# - currently closed AND target is open -> open transition
is_open = gripper_state_normalized < self.open_threshold
is_closed = gripper_state_normalized > self.closed_threshold
cmd_close = gripper_action_normalized > self.closed_threshold
cmd_open = gripper_action_normalized < self.open_threshold
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
gripper_penalty = self.penalty * int(gripper_penalty_bool)
@@ -409,11 +426,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the penalty value and max gripper position.
A dictionary containing the penalty value, max gripper position,
and the open/closed thresholds.
"""
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
"open_threshold": self.open_threshold,
"closed_threshold": self.closed_threshold,
}
def reset(self) -> None:
@@ -134,6 +134,24 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
(already ``(C, 1, 1)``). Needed by RL training, which can start without
a dataset and supplies stats manually via JSON config.
"""
for key, feature in self.features.items():
if feature.type != FeatureType.VISUAL:
continue
if key not in self._tensor_stats:
continue
for stat_name, stat_tensor in self._tensor_stats[key].items():
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
continue
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -152,6 +170,7 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -201,6 +220,7 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -211,6 +231,7 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
@@ -30,7 +30,7 @@ class RewardClassifierConfig(RewardModelConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
model_name: str = "lerobot/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
@@ -105,6 +105,7 @@ class Classifier(PreTrainedRewardModel):
def __init__(
self,
config: RewardClassifierConfig,
**kwargs,
):
from transformers import AutoModel
+26 -16
View File
@@ -12,23 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reinforcement learning modules.
"""Reinforcement learning modules.
Requires: ``pip install 'lerobot[hilserl]'``
Available modules (import directly)::
from lerobot.rl.actor import ...
from lerobot.rl.learner import ...
from lerobot.rl.learner_service import ...
from lerobot.rl.buffer import ...
from lerobot.rl.eval_policy import ...
from lerobot.rl.gym_manipulator import ...
Distributed actor / learner entry points (``actor``, ``learner``,
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
buffer, data sources and trainer are gRPC-free and usable standalone.
"""
from lerobot.utils.import_utils import require_package
from .algorithms.base import RLAlgorithm as RLAlgorithm
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
from .algorithms.factory import (
make_algorithm as make_algorithm,
make_algorithm_config as make_algorithm_config,
)
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
from .buffer import ReplayBuffer as ReplayBuffer
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
from .trainer import RLTrainer as RLTrainer
require_package("grpcio", extra="hilserl", import_name="grpc")
__all__: list[str] = []
__all__ = [
"RLAlgorithm",
"RLAlgorithmConfig",
"TrainingStats",
"make_algorithm",
"make_algorithm_config",
"SACAlgorithmConfig",
"RLTrainer",
"ReplayBuffer",
"DataMixer",
"OnlineOfflineMixer",
]
+113 -78
View File
@@ -49,39 +49,53 @@ https://github.com/michel-aractingi/lerobot-hilserl-guide
import logging
import os
import time
from collections.abc import Generator
from functools import lru_cache
from queue import Empty
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import _grpc_available, require_package
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
bytes_to_state_dict = None
grpc_channel_options = None
python_object_to_bytes = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
transitions_to_bytes = None
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from torch.multiprocessing import Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.utils.utils import (
@@ -89,19 +103,24 @@ from lerobot.utils.utils import (
init_logging,
)
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .gym_manipulator import (
create_transition,
make_processors,
make_robot_env,
reset_and_build_transition,
step_env_and_process_transition,
)
from .queue import get_last_item_from_queue
from .train_rl import TrainRLServerPipelineConfig
# Main entry point
@parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
cfg.validate()
display_pid = False
if not use_threads(cfg):
@@ -212,7 +231,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
def act_with_policy(
cfg: TrainRLServerPipelineConfig,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
@@ -252,22 +271,24 @@ def act_with_policy(
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### To avoid sending a policy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
policy = policy.to(device).eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Build the algorithm
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
transition = reset_and_build_transition(online_env, env_processor, action_processor)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
@@ -291,8 +312,17 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
# Unnormalize only the continuous part.
if cfg.policy.num_discrete_actions is not None:
continuous_action = postprocessor.process_action(action[..., :-1])
discrete_action = action[..., -1:].to(
device=continuous_action.device, dtype=continuous_action.dtype
)
action = torch.cat([continuous_action, discrete_action], dim=-1)
else:
action = postprocessor.process_action(action)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -326,7 +356,8 @@ def act_with_policy(
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
if is_intervention:
episode_intervention = True
episode_intervention_steps += 1
@@ -334,6 +365,7 @@ def act_with_policy(
"discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
),
TeleopEvents.IS_INTERVENTION.value: is_intervention,
}
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
@@ -354,7 +386,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -390,14 +422,7 @@ def act_with_policy(
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
transition = reset_and_build_transition(online_env, env_processor, action_processor)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
@@ -408,10 +433,10 @@ def act_with_policy(
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
stub: "services_pb2_grpc.LearnerServiceStub",
shutdown_event: Any, # Event
attempts: int = 30,
):
) -> bool:
"""Establish a connection with the learner.
Args:
@@ -441,12 +466,14 @@ def establish_learner_connection(
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""
Returns a client for the learner service.
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
"""Return a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
Returns:
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
"""
channel = grpc.insecure_channel(
@@ -461,16 +488,18 @@ def learner_service_client(
def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Receive parameters from the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
@@ -513,12 +542,11 @@ def receive_policy(
def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send transitions to the learner.
This function continuously retrieves messages from the queue and processes:
@@ -526,6 +554,13 @@ def send_transitions(
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
transitions_queue (Queue): The queue to receive the transitions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -563,18 +598,24 @@ def send_transitions(
def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
shutdown_event: Any, # Event
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
grpc_channel: "grpc.Channel | None" = None,
) -> None:
"""Send interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
interactions_queue (Queue): The queue to receive the interactions.
shutdown_event (Event): The event to check if the process should shutdown.
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
grpc_channel (grpc.Channel | None): Optional pre-created channel.
"""
if not use_threads(cfg):
@@ -613,7 +654,11 @@ def send_interactions(
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
def transitions_stream(
shutdown_event: Any, # Event
transitions_queue: Queue,
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
@@ -629,10 +674,10 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout:
def interactions_stream(
shutdown_event: Event,
shutdown_event: Any, # Event
interactions_queue: Queue,
timeout: float, # type: ignore
) -> services_pb2.Empty:
timeout: float,
) -> "Generator[Any, None, services_pb2.Empty]":
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
@@ -652,7 +697,8 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
@@ -667,18 +713,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
algorithm.load_weights(state_dicts, device=device)
# Utilities functions
+20
View File
@@ -0,0 +1,20 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sac import SACAlgorithm, SACAlgorithmConfig
__all__ = [
"SACAlgorithm",
"SACAlgorithmConfig",
]
+207
View File
@@ -0,0 +1,207 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import os
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
from torch.optim import Optimizer
from lerobot.types import BatchType
from lerobot.utils.hub import HubMixin
from .configs import RLAlgorithmConfig, TrainingStats
if TYPE_CHECKING:
from torch import nn
from ..data_sources.data_mixer import DataMixer
T = TypeVar("T", bound="RLAlgorithm")
class RLAlgorithm(HubMixin, abc.ABC):
"""Base for all RL algorithms."""
config_class: type[RLAlgorithmConfig]
name: str
config: RLAlgorithmConfig
@abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One complete training step.
The algorithm calls ``next(batch_iterator)`` as many times as it
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
The iterator is owned by the trainer; the algorithm just consumes
from it.
"""
raise NotImplementedError
def configure_data_iterator(
self,
data_mixer: DataMixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
) -> Iterator[BatchType]:
"""Create the data iterator this algorithm needs.
The default implementation uses the standard ``data_mixer.get_iterator()``.
Algorithms that need specialised sampling should override this method.
"""
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
@abc.abstractmethod
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""Build and return the optimizers used during training.
Called once on the learner side after construction.
"""
raise NotImplementedError
def get_optimizers(self) -> dict[str, Optimizer]:
"""Return optimizers for checkpointing / external scheduling."""
return {}
@property
def optimization_step(self) -> int:
"""Current learner optimization step.
Part of the stable contract for checkpoint/resume. Algorithms can
either use this default storage or override for custom behavior.
"""
return getattr(self, "_optimization_step", 0)
@optimization_step.setter
def optimization_step(self, value: int) -> None:
self._optimization_step = int(value)
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors."""
return {}
@abc.abstractmethod
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner."""
raise NotImplementedError
@abc.abstractmethod
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Must return a flat tensor mapping for everything the algorithm owns
that is not part of the policy (e.g. critic ensembles, target networks,
temperature parameters). Algorithms with no training-only tensors
should explicitly return an empty dict.
"""
raise NotImplementedError
@abc.abstractmethod
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
Implementations MUST keep the identity of any ``nn.Parameter`` that an
optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
rather than rebinding the attribute.
"""
raise NotImplementedError
def _save_pretrained(self, save_directory: Path) -> None:
"""Persist the algorithm's tensors and config to ``save_directory``.
Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`).
"""
tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()}
save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE))
self.config._save_pretrained(save_directory)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
policy: nn.Module,
config: RLAlgorithmConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
device: str | torch.device = "cpu",
**algo_kwargs: Any,
) -> T:
"""Build an algorithm and load its weights from ``pretrained_name_or_path``."""
if config is None:
config = cls.config_class.from_pretrained(
pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if hasattr(config, "policy_config"):
config.policy_config = policy.config
instance = cls(policy=policy, config=config, **algo_kwargs)
model_id = str(pretrained_name_or_path)
if os.path.isdir(model_id):
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
else:
try:
model_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
) from e
tensors = load_safetensors(model_file)
instance.load_state_dict(tensors, device=device)
return instance
+138
View File
@@ -0,0 +1,138 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="RLAlgorithmConfig")
logger = logging.getLogger(__name__)
@dataclass
class TrainingStats:
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
losses: dict[str, float] = field(default_factory=dict)
grad_norms: dict[str, float] = field(default_factory=dict)
extra: dict[str, float] = field(default_factory=dict)
def to_log_dict(self) -> dict[str, float]:
"""Flatten all stats into a single dict for logging."""
d: dict[str, float] = {}
for name, val in self.losses.items():
d[name] = val
for name, val in self.grad_norms.items():
d[f"{name}_grad_norm"] = val
for name, val in self.extra.items():
d[name] = val
return d
@dataclass
class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""Registry for algorithm configs."""
@property
def type(self) -> str:
"""Registered name of this algorithm config (e.g. ``"sac"``)."""
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
@classmethod
@abc.abstractmethod
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
"""Build an algorithm config from a policy config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
def _save_pretrained(self, save_directory: Path) -> None:
"""Serialize this config as ``config.json`` inside ``save_directory``."""
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**algo_kwargs: Any,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with draccus.config_type("json"):
instance = draccus.parse(RLAlgorithmConfig, config_file, args=[])
if cls is not RLAlgorithmConfig and not isinstance(instance, cls):
raise TypeError(
f"Config at {model_id} has type '{instance.type}' but was loaded via "
f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()."
)
for key, value in algo_kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
return instance
+99
View File
@@ -0,0 +1,99 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from .base import RLAlgorithm
from .configs import RLAlgorithmConfig
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
Args:
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
**kwargs: Keyword arguments forwarded to the config class constructor.
Returns:
An instance of the matching ``RLAlgorithmConfig`` subclass.
Raises:
ValueError: If ``algorithm_type`` is not registered.
"""
try:
cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
except KeyError as err:
raise ValueError(
f"Algorithm type '{algorithm_type}' is not registered. "
f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}"
) from err
return cls(**kwargs)
def get_algorithm_class(name: str) -> type[RLAlgorithm]:
"""
Retrieves an RL algorithm class by its registered name.
This function uses dynamic imports to avoid loading all algorithm classes into
memory at once, improving startup time and reducing dependencies.
Args:
name: The name of the algorithm. Supported names are "sac".
Returns:
The algorithm class corresponding to the given name.
Raises:
ValueError: If the algorithm name is not recognized.
"""
if name == "sac":
from .sac.sac_algorithm import SACAlgorithm
return SACAlgorithm
raise ValueError(
f"Algorithm type '{name}' is not available. "
f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}"
)
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
"""
Instantiate an RL algorithm.
This factory function looks up the :class:`RLAlgorithm` subclass that matches
``cfg.type`` and instantiates it with the provided policy. It also enforces
that ``cfg.policy_config`` has been populated before construction (this is
normally handled by :meth:`TrainRLServerPipelineConfig.validate`).
Args:
cfg: The algorithm configuration. Must have ``policy_config`` set.
policy: The policy module the algorithm will train.
Returns:
An instantiated :class:`RLAlgorithm`.
Raises:
ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not
registered.
"""
if getattr(cfg, "policy_config", None) is None:
raise ValueError(
f"{type(cfg).__name__}.policy_config is None. "
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
"before calling make_algorithm()."
)
cls = get_algorithm_class(cfg.type)
return cls(policy=policy, config=cfg)
+18
View File
@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACAlgorithmConfig
from .sac_algorithm import SACAlgorithm
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
@@ -0,0 +1,99 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
CriticNetworkConfig,
GaussianActorConfig,
)
from ..configs import RLAlgorithmConfig
@RLAlgorithmConfig.register_subclass("sac")
@dataclass
class SACAlgorithmConfig(RLAlgorithmConfig):
"""Soft Actor-Critic (SAC) algorithm configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum
entropy reinforcement learning framework. It learns a policy and a Q-function
simultaneously using experience collected from the environment.
This configuration class contains the algorithm-side hyperparameters: critic
ensemble, target networks, temperature / entropy tuning, and the Bellman
update loop. The policy-side (actor + observation encoder) lives in
:class:`~lerobot.policies.gaussian_actor.GaussianActorConfig` and is
referenced via :attr:`policy_config`.
"""
# Optimizer learning rates
# Learning rate for the actor network
actor_lr: float = 3e-4
# Learning rate for the critic network
critic_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4
# Bellman update
# Discount factor for the SAC algorithm
discount: float = 0.99
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True
# Weight for the critic target update
critic_target_update_weight: float = 0.005
# Critic ensemble
# Number of critics in the ensemble
num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Temperature / entropy
# Initial temperature value
temperature_init: float = 1.0
# Target entropy for automatic temperature tuning. If ``None``, defaults to
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
# there is a discrete action head).
target_entropy: float | None = None
# Update loop
# Update-to-data ratio. Set to >1 to enable extra critic updates per env step.
utd_ratio: int = 1
# Frequency of policy updates
policy_update_freq: int = 1
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0
# Optimizations
# torch.compile is currently disabled by default
use_torch_compile: bool = False
# Policy config
policy_config: PreTrainedConfig | None = None
@classmethod
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config with default hyperparameters for a given policy."""
return cls(
policy_config=policy_cfg,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
)
@@ -0,0 +1,672 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from collections.abc import Callable, Iterator
from dataclasses import asdict
from typing import Any
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
DISCRETE_DIMENSION_INDEX,
MLP,
DiscreteCritic,
GaussianActorObservationEncoder,
GaussianActorPolicy,
orthogonal_init,
)
from lerobot.policies.utils import get_device_from_parameters
from lerobot.types import BatchType
from lerobot.utils.constants import ACTION
from lerobot.utils.transition import move_state_dict_to_device
from ..base import RLAlgorithm
from ..configs import TrainingStats
from .configuration_sac import SACAlgorithmConfig
class SACAlgorithm(RLAlgorithm):
"""Soft Actor-Critic. Owns critics, targets, temperature, and loss computation."""
config_class = SACAlgorithmConfig
name = "sac"
def __init__(
self,
policy: GaussianActorPolicy,
config: SACAlgorithmConfig,
):
self.config = config
self.policy_config = config.policy_config
self.policy = policy
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
action_dim = self.policy.config.output_features[ACTION].shape[0]
self._init_critics(action_dim)
self._init_temperature(action_dim)
self._device = torch.device(self.policy.config.device)
self._move_to_device()
def _init_critics(self, action_dim) -> None:
"""Build critic ensemble, targets."""
encoder = self.policy.encoder_critic
heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
target_heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
# TODO(Khalil): Investigate and fix torch.compile
# NOTE: torch.compile is disabled, policy does not converge when enabled.
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.discrete_critic_target = None
if self.policy_config.num_discrete_actions is not None:
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
"""Build target discrete critic (main network is owned by the policy)."""
discrete_critic_target = DiscreteCritic(
encoder=encoder,
input_dim=encoder.output_dim,
output_dim=self.policy_config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO(Khalil): Compile the discrete critic
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
return discrete_critic_target
def _init_temperature(self, continuous_action_dim: int) -> None:
"""Set up temperature parameter (log_alpha) and target entropy."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
total_action_dim = continuous_action_dim + (
1 if self.policy_config.num_discrete_actions is not None else 0
)
self.target_entropy = -total_action_dim / 2
def _move_to_device(self) -> None:
self.policy.to(self._device)
self.critic_ensemble.to(self._device)
self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if self.discrete_critic_target is not None:
self.discrete_critic_target.to(self._device)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def _critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def _discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Run one SAC training step (critic / discrete-critic / actor / temperature).
Pulls ``utd_ratio`` batches from ``batch_iterator``, computes the relevant
losses, backpropagates each, and updates target networks.
Args:
batch_iterator: yields batches each containing
- ``action``: Action tensor
- ``reward``: Reward tensor
- ``state``: Observations tensor dict
- ``next_state``: Next observations tensor dict
- ``done``: Done mask tensor
- ``observation_feature``: Optional pre-computed observation features
- ``next_observation_feature``: Optional pre-computed next observation features
- ``complementary_info`` (optional): per-step extras like discrete penalties
Returns:
TrainingStats with per-component losses and grad norms.
"""
clip = self.config.grad_clip_norm
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch, include_complementary_info=True)
loss_critic = self._compute_loss_critic(fb)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
self.optimizers["critic"].step()
if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
self.optimizers["discrete_critic"].step()
self._update_target_networks()
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch, include_complementary_info=False)
loss_critic = self._compute_loss_critic(fb)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item()
self.optimizers["critic"].step()
stats = TrainingStats(
losses={"loss_critic": loss_critic.item()},
grad_norms={"critic": critic_grad},
)
if self.policy_config.num_discrete_actions is not None:
loss_dc = self._compute_loss_discrete_critic(fb)
self.optimizers["discrete_critic"].zero_grad()
loss_dc.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(), max_norm=clip
).item()
self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_dc.item()
stats.grad_norms["discrete_critic"] = dc_grad
if self._optimization_step % self.config.policy_update_freq == 0:
for _ in range(self.config.policy_update_freq):
loss_actor = self._compute_loss_actor(fb)
self.optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad = torch.nn.utils.clip_grad_norm_(
self.policy.actor.parameters(), max_norm=clip
).item()
self.optimizers["actor"].step()
loss_temp = self._compute_loss_temperature(fb)
self.optimizers["temperature"].zero_grad()
loss_temp.backward()
temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item()
self.optimizers["temperature"].step()
stats.losses["loss_actor"] = loss_actor.item()
stats.losses["loss_temperature"] = loss_temp.item()
stats.grad_norms["actor"] = actor_grad
stats.grad_norms["temperature"] = temp_grad
stats.extra["temperature"] = self.temperature
self._update_target_networks()
self._optimization_step += 1
return stats
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
# Extract common components from batch
observations = batch["state"]
actions = batch[ACTION]
observation_features = batch.get("observation_feature")
# Extract critic-specific components
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
next_observation_features = batch.get("next_observation_feature")
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.policy.actor(
next_observations, next_observation_features
)
# 2- compute q targets
q_targets = self._critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.policy_config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self._critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
observation_features = batch.get("observation_feature")
next_observation_features = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self._discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self._discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self._discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
observation_features = batch.get("observation_feature")
actions_pi, log_probs, _ = self.policy.actor(observations, observation_features)
q_preds = self._critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
"""Compute the temperature loss"""
observations = batch["state"]
observation_features = batch.get("observation_feature")
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.policy.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def _update_target_networks(self) -> None:
"""Update target networks with exponential moving average"""
for target_p, p in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
):
target_p.data.copy_(
p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
)
if self.policy_config.num_discrete_actions is not None:
for target_p, p in zip(
self.discrete_critic_target.parameters(),
self.policy.discrete_critic.parameters(),
strict=True,
):
target_p.data.copy_(
p.data * self.config.critic_target_update_weight
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
)
def _prepare_forward_batch(
self, batch: BatchType, *, include_complementary_info: bool = True
) -> dict[str, Any]:
observations = batch["state"]
next_observations = batch["next_state"]
observation_features, next_observation_features = self.get_observation_features(
observations, next_observations
)
forward_batch: dict[str, Any] = {
ACTION: batch[ACTION],
"reward": batch["reward"],
"state": observations,
"next_state": next_observations,
"done": batch["done"],
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
if include_complementary_info and "complementary_info" in batch:
forward_batch["complementary_info"] = batch["complementary_info"]
return forward_batch
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
A dictionary mapping component names ("actor", "critic", "temperature")
to their respective Adam optimizers.
"""
actor_params = self.policy.get_optim_params()["actor"]
self.optimizers = {
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
}
if self.policy_config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam(
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
)
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
def get_weights(self) -> dict[str, Any]:
"""Send actor + discrete-critic state dicts."""
state_dicts: dict[str, Any] = {
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
}
if self.policy_config.num_discrete_actions is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
self.policy.discrete_critic.state_dict(), device="cpu"
)
return state_dicts
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load actor + discrete-critic weights into the policy."""
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
self.policy.actor.load_state_dict(actor_sd)
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
self.policy.discrete_critic.load_state_dict(discrete_sd)
def state_dict(self) -> dict[str, torch.Tensor]:
"""Algorithm-owned trainable tensors.
Encoder weights are stripped because they are owned by the policy
(``policy.encoder_critic``) and already saved via ``policy.save_pretrained``.
"""
bundle: dict[str, torch.Tensor] = {}
for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items():
bundle[f"critic_ensemble.{k}"] = v
for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items():
bundle[f"critic_target.{k}"] = v
if self.discrete_critic_target is not None:
for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items():
bundle[f"discrete_critic_target.{k}"] = v
bundle["log_alpha"] = self.log_alpha.detach()
return bundle
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
device: str | torch.device = "cpu",
) -> None:
"""In-place load of algorithm-owned tensors.
``log_alpha`` is restored via ``Parameter.data.copy_`` so the
``temperature`` optimizer's reference to the parameter object stays
valid after resume.
"""
critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.")
critic_target_state = _split_prefix(state_dict, "critic_target.")
self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False)
self.critic_target.load_state_dict(critic_target_state, strict=False)
if self.discrete_critic_target is not None:
discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.")
self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False)
if "log_alpha" in state_dict:
self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device))
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
"""
Get observation features from the policy encoder. It act as cache for the observation features.
when the encoder is frozen, the observation features are not updated.
We can save compute by caching the observation features.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Drop ``encoder.*`` keys from a critic-module state dict."""
return {k: v for k, v in state.items() if not k.startswith("encoder.")}
def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
"""Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped."""
return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)}
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: float | None = None,
init_final: float | None = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (GaussianActorObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: GaussianActorObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
+3 -3
View File
@@ -97,8 +97,8 @@ class ReplayBuffer:
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
state_keys (list[str]): The list of keys that appear in `state` and `next_state`.
image_augmentation_function (Callable | None): A function that takes a batch of images
and returns a batch of augmented images. If None, a default augmentation function is used.
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
@@ -634,7 +634,7 @@ class ReplayBuffer:
If None, you must handle or define default keys.
Returns:
transitions (List[Transition]):
transitions (list[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:
+2 -2
View File
@@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
crop_params_dict (dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
resize_size (tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.types import BatchType
from .data_mixer import DataMixer, OnlineOfflineMixer
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
+97
View File
@@ -0,0 +1,97 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from lerobot.types import BatchType
from ..buffer import ReplayBuffer, concatenate_batch_transitions
class DataMixer(abc.ABC):
"""Abstract interface for all data mixing strategies."""
@abc.abstractmethod
def sample(self, batch_size: int) -> BatchType:
"""Draw one batch of ``batch_size`` transitions."""
raise NotImplementedError
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Infinite iterator that yields batches."""
while True:
yield self.sample(batch_size)
class OnlineOfflineMixer(DataMixer):
"""Mixes transitions from an online and an offline replay buffer."""
def __init__(
self,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer | None = None,
online_ratio: float = 1.0,
):
if not 0.0 <= online_ratio <= 1.0:
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
self.online_buffer = online_buffer
self.offline_buffer = offline_buffer
self.online_ratio = online_ratio
def sample(self, batch_size: int) -> BatchType:
if self.offline_buffer is None:
return self.online_buffer.sample(batch_size)
n_online = max(1, int(batch_size * self.online_ratio))
n_offline = batch_size - n_online
online_batch = self.online_buffer.sample(n_online)
offline_batch = self.offline_buffer.sample(n_offline)
return concatenate_batch_transitions(online_batch, offline_batch)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Yield batches by composing buffer async iterators."""
n_online = max(1, int(batch_size * self.online_ratio))
online_iter = self.online_buffer.get_iterator(
batch_size=n_online,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
if self.offline_buffer is None:
yield from online_iter
return
n_offline = batch_size - n_online
offline_iter = self.offline_buffer.get_iterator(
batch_size=n_offline,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
while True:
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))
+1 -1
View File
@@ -17,7 +17,6 @@ import logging
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_policy
from lerobot.robots import ( # noqa: F401
@@ -31,6 +30,7 @@ from lerobot.teleoperators import (
)
from .gym_manipulator import make_robot_env
from .train_rl import TrainRLServerPipelineConfig
logging.basicConfig(level=logging.INFO)
+108 -73
View File
@@ -74,6 +74,7 @@ from lerobot.teleoperators import (
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.import_utils import require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
@@ -312,6 +313,7 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
# Check if this is a GymHIL simulation environment
if cfg.name == "gym_hil":
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
require_package("gym-hil", extra="hilserl", import_name="gym_hil")
import gym_hil # noqa: F401
# Extract gripper settings with defaults
@@ -383,10 +385,21 @@ def make_processors(
GymHILAdapterProcessorStep(),
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
# Add time limit processor if reset config exists
if cfg.processor.reset is not None:
env_pipeline_steps.append(
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
)
env_pipeline_steps.extend(
[
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
)
return DataProcessorPipeline(
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
), DataProcessorPipeline(
@@ -551,8 +564,19 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
# Merge env and action-processor info: env wins for str keys, action-processor
# wins for `TeleopEvents` enum keys
action_info = processed_action_transition[TransitionKey.INFO]
new_info = info.copy()
new_info.update(processed_action_transition[TransitionKey.INFO])
for key, value in action_info.items():
if isinstance(key, TeleopEvents):
new_info[key] = value
new_transition = create_transition(
observation=obs,
@@ -568,6 +592,24 @@ def step_env_and_process_transition(
return new_transition
def reset_and_build_transition(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
) -> EnvTransition:
"""Reset env + processors and return the first env-processed transition."""
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
complementary_data: dict[str, Any] = {}
if hasattr(env, "get_raw_joint_positions"):
raw_joint_positions = env.get_raw_joint_positions()
if raw_joint_positions is not None:
complementary_data["raw_joint_positions"] = raw_joint_positions
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
return env_processor(data=transition)
def control_loop(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
@@ -593,17 +635,7 @@ def control_loop(
print("- When not intervening, robot will stay still")
print("- Press Ctrl+C to exit")
# Reset environment and processors
obs, info = env.reset()
complementary_data = (
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
)
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(data=transition)
transition = reset_and_build_transition(env, env_processor, action_processor)
# Determine if gripper is used
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
@@ -659,79 +691,82 @@ def control_loop(
episode_step = 0
episode_start_time = time.perf_counter()
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
try:
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
# Use the new step function
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {
observation = {
k: v.squeeze(0).cpu()
for k, v in transition[TransitionKey.OBSERVATION].items()
if isinstance(v, torch.Tensor)
}
# Use teleop_action if available, otherwise use the action from the transition
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
frame = {
**observations,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
if cfg.mode == "record":
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
)
frame = {
**observation,
ACTION: action_to_record.cpu(),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"discrete_penalty", 0.0
)
frame["complementary_info.discrete_penalty"] = np.array(
[discrete_penalty], dtype=np.float32
)
episode_step += 1
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
episode_step += 1
if dataset is not None:
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
# Reset for new episode
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
if dataset is not None:
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# Reset for new episode
transition = reset_and_build_transition(env, env_processor, action_processor)
# Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
# Maintain fps timing
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
finally:
if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None:
logging.info("Waiting for image writer to finish...")
dataset.writer.image_writer.stop()
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Finalizing dataset before pushing to hub")
+123 -309
View File
@@ -51,9 +51,21 @@ import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import _grpc_available, require_package
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2_grpc
else:
grpc = None
services_pb2_grpc = None
import grpc
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_file as load_safetensors
from termcolor import colored
from torch import nn
from torch.multiprocessing import Queue
@@ -68,14 +80,11 @@ from lerobot.common.train_utils import (
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies import make_policy, make_pre_post_processors
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2_grpc
from lerobot.transport.utils import (
MAX_MESSAGE_SIZE,
bytes_to_python_object,
@@ -84,26 +93,35 @@ from lerobot.transport.utils import (
)
from lerobot.utils.constants import (
ACTION,
ALGORITHM_DIR,
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.random_utils import set_seed
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
init_logging,
)
from .buffer import ReplayBuffer, concatenate_batch_transitions
from .algorithms.base import RLAlgorithm
from .algorithms.factory import make_algorithm
from .buffer import ReplayBuffer
from .data_sources import OnlineOfflineMixer
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
from .train_rl import TrainRLServerPipelineConfig
from .trainer import RLTrainer
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
require_package("grpcio", extra="hilserl", import_name="grpc")
if not use_threads(cfg):
import torch.multiprocessing as mp
@@ -179,7 +197,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
def start_learner_threads(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
) -> None:
"""
Start the learner threads for training.
@@ -253,7 +271,7 @@ def start_learner_threads(
def add_actor_information_and_train(
cfg: TrainRLServerPipelineConfig,
wandb_logger: WandBLogger | None,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
@@ -266,8 +284,8 @@ def add_actor_information_and_train(
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Delegates training updates to an ``RLAlgorithm``.
- Periodically pushes updated weights to actors.
- Logs training statistics, including loss values and optimization frequency.
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
@@ -286,17 +304,13 @@ def add_actor_information_and_train(
# of 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
# Initialize logging for multiprocessing
if not use_threads(cfg):
@@ -308,7 +322,7 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
@@ -317,15 +331,17 @@ def add_actor_information_and_train(
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
# Push initial policy weights to actors
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
@@ -338,21 +354,37 @@ def add_actor_information_and_train(
device=device,
storage_device=storage_device,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
# DataMixer: online-only or online/offline 50-50 mix
data_mixer = OnlineOfflineMixer(
online_buffer=replay_buffer,
offline_buffer=offline_replay_buffer,
online_ratio=cfg.online_ratio,
)
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=batch_size,
preprocessor=preprocessor,
)
# If we are resuming, we need to load the training state
optimizers = algorithm.get_optimizers()
resume_optimization_step, resume_interaction_step = load_training_state(
cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
)
logging.info("Starting learner thread")
interaction_message = None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
algorithm.optimization_step = optimization_step
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@@ -365,7 +397,6 @@ def add_actor_information_and_train(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
offline_replay_buffer=offline_replay_buffer,
device=device,
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
@@ -382,180 +413,20 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["discrete_critic"].step()
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
critic_output = policy.forward(forward_batch, model="critic")
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["discrete_critic"].step()
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
stats = trainer.training_step()
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
last_time_policy_pushed = time.time()
# Update target networks (main and discrete)
policy.update_target_networks()
training_infos = stats.to_log_dict()
# Log training metrics at specified intervals
optimization_step = algorithm.optimization_step
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
@@ -583,7 +454,6 @@ def add_actor_information_and_train(
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
@@ -597,9 +467,12 @@ def add_actor_information_and_train(
policy=policy,
optimizers=optimizers,
replay_buffer=replay_buffer,
algorithm=algorithm,
offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id,
fps=fps,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
@@ -607,7 +480,7 @@ def start_learner(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
shutdown_event: Any, # Event
cfg: TrainRLServerPipelineConfig,
):
"""
@@ -681,9 +554,12 @@ def save_training_checkpoint(
policy: nn.Module,
optimizers: dict[str, Optimizer],
replay_buffer: ReplayBuffer,
algorithm: RLAlgorithm | None = None,
offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None,
fps: int = 30,
preprocessor=None,
postprocessor=None,
) -> None:
"""
Save training checkpoint and associated data.
@@ -707,6 +583,8 @@ def save_training_checkpoint(
offline_replay_buffer: Optional offline replay buffer to save
dataset_repo_id: Repository ID for dataset
fps: Frames per second for dataset
preprocessor: Optional preprocessor pipeline to save
postprocessor: Optional postprocessor pipeline to save
"""
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
@@ -715,7 +593,7 @@ def save_training_checkpoint(
# Create checkpoint directory
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
# Save checkpoint
# Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/).
save_checkpoint(
checkpoint_dir=checkpoint_dir,
step=optimization_step,
@@ -723,13 +601,22 @@ def save_training_checkpoint(
policy=policy,
optimizer=optimizers,
scheduler=None,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
# Save interaction step manually
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
os.makedirs(training_state_dir, exist_ok=True)
training_state = {"step": optimization_step, "interaction_step": interaction_step}
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
# Algorithm-owned tensors live in their own component subfolder
# so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
if algorithm is not None:
algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
# Enrich training_step.json with the RL-specific interaction_step counter so
# both can be restored from a single file.
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
write_json(
{"step": optimization_step, "interaction_step": interaction_step},
training_state_dir / TRAINING_STEP,
)
# Update the "last" symlink
update_last_checkpoint(checkpoint_dir)
@@ -760,58 +647,6 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_discrete_critic = torch.optim.Adam(
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["discrete_critic"] = optimizer_discrete_critic
return optimizers, lr_scheduler
# Training setup functions
@@ -875,13 +710,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli
def load_training_state(
cfg: TrainRLServerPipelineConfig,
optimizers: Optimizer | dict[str, Optimizer],
algorithm: RLAlgorithm | None = None,
device: str | torch.device = "cpu",
):
"""
Loads the training state (optimizers, step count, etc.) from a checkpoint.
Loads the training state (optimizers, RNG, step + interaction step, and
algorithm-owned tensors) from the most recent checkpoint.
Args:
cfg (TrainRLServerPipelineConfig): Training configuration
optimizers (Optimizer | dict): Optimizers to load state into
cfg: Training configuration.
optimizers: Optimizers to load state into.
algorithm: Algorithm whose state dict should be restored.
Required for full main-equivalent resume;
the policy itself is restored separately via ``make_policy``.
device: Device on which to place loaded algorithm tensors.
Returns:
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
@@ -890,20 +732,31 @@ def load_training_state(
return None, None
# Construct path to the last checkpoint directory
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
logging.info(f"Loading training state from {checkpoint_dir}")
try:
# Use the utility function from train_utils which loads the optimizer state
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
# Restore optimizers + RNG + step from the standard `training_state/` folder
step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
# Load interaction step separately from training_state.pt
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
interaction_step = 0
if os.path.exists(training_state_path):
training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load
interaction_step = training_state.get("interaction_step", 0)
# Restore algorithm-owned tensors
if algorithm is not None:
algo_dir = checkpoint_dir / ALGORITHM_DIR
if algo_dir.is_dir():
tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
algorithm.load_state_dict(tensors, device=device)
logging.info(f"Loaded algorithm state from {algo_dir}")
else:
logging.warning(
f"No algorithm state found at {algo_dir}; "
"will keep their freshly-initialised values. Adam moments restored from the "
"old optimizer state may not match these reset parameters."
)
# Read interaction_step from the enriched training_step.json
training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
return step, interaction_step
@@ -1016,33 +869,6 @@ def initialize_offline_replay_buffer(
# Utilities/Helpers functions
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""
Get observation features from the policy encoder. It act as cache for the observation features.
when the encoder is frozen, the observation features are not updated.
We can save compute by caching the observation features.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.learner == "threads"
@@ -1093,19 +919,11 @@ def check_nan_in_transition(
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None:
logging.debug("[LEARNER] Pushing actor policy to the queue")
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_dicts = algorithm.get_weights()
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
@@ -1129,9 +947,8 @@ def process_transitions(
transition_queue: Queue,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
device: str,
dataset_repo_id: str | None,
shutdown_event: any,
shutdown_event: Any, # Event
):
"""Process all available transitions from the queue.
@@ -1139,7 +956,6 @@ def process_transitions(
transition_queue: Queue for receiving transitions from the actor
replay_buffer: Replay buffer to add transitions to
offline_replay_buffer: Offline replay buffer to add transitions to
device: Device to move transitions to
dataset_repo_id: Repository ID for dataset
shutdown_event: Event to signal shutdown
"""
@@ -1148,8 +964,6 @@ def process_transitions(
transition_list = bytes_to_transitions(buffer=transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition=transition, device=device)
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],
@@ -1163,7 +977,7 @@ def process_transitions(
# Add to offline buffer if it's an intervention
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
TeleopEvents.IS_INTERVENTION
TeleopEvents.IS_INTERVENTION.value
):
offline_replay_buffer.add(**transition)
@@ -1172,7 +986,7 @@ def process_interaction_messages(
interaction_message_queue: Queue,
interaction_step_shift: int,
wandb_logger: WandBLogger | None,
shutdown_event: any,
shutdown_event: Any, # Event
) -> dict | None:
"""Process all available interaction messages from the queue.
+24 -7
View File
@@ -18,17 +18,32 @@
import logging
import time
from multiprocessing import Event, Queue
from typing import TYPE_CHECKING
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.utils.import_utils import _grpc_available
from .queue import get_last_item_from_queue
if TYPE_CHECKING or _grpc_available:
import grpc
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
else:
grpc = None
services_pb2 = None
services_pb2_grpc = None
receive_bytes_in_chunks = None
send_bytes_in_chunks = None
_ServicerBase = object
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
class LearnerService(_ServicerBase):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
@@ -51,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802
def StreamParameters( # noqa: N802
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
@@ -86,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
@@ -100,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
@@ -114,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802
def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
return services_pb2.Empty()
+50
View File
@@ -0,0 +1,50 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Top-level pipeline config for distributed RL training (actor / learner)."""
from __future__ import annotations
from dataclasses import dataclass
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from .algorithms.configs import RLAlgorithmConfig
from .algorithms.factory import make_algorithm_config
from .algorithms.sac import SACAlgorithmConfig # noqa: F401
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
# Algorithm config.
algorithm: RLAlgorithmConfig | None = None
# Data mixer strategy name. Currently supports "online_offline".
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer.
online_ratio: float = 0.5
def validate(self) -> None:
super().validate()
if self.algorithm is None:
self.algorithm = make_algorithm_config("sac")
if getattr(self.algorithm, "policy_config", None) is None:
self.algorithm.policy_config = self.policy
+101
View File
@@ -0,0 +1,101 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
from lerobot.types import BatchType
from .algorithms.base import RLAlgorithm
from .algorithms.configs import TrainingStats
from .data_sources.data_mixer import DataMixer
class RLTrainer:
"""Unified training step orchestrator.
Holds the algorithm, a DataMixer, and an optional preprocessor.
"""
def __init__(
self,
algorithm: RLAlgorithm,
data_mixer: DataMixer,
batch_size: int,
*,
preprocessor: Any | None = None,
):
self.algorithm = algorithm
self.data_mixer = data_mixer
self.batch_size = batch_size
self._preprocessor = preprocessor
self._iterator: Iterator[BatchType] | None = None
self.algorithm.make_optimizers_and_scheduler()
def _build_data_iterator(self) -> Iterator[BatchType]:
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
raw = self.algorithm.configure_data_iterator(
data_mixer=self.data_mixer,
batch_size=self.batch_size,
)
if self._preprocessor is not None:
return _PreprocessedIterator(raw, self._preprocessor)
return raw
def reset_data_iterator(self) -> None:
"""Discard the current iterator so it will be rebuilt lazily next step."""
self._iterator = None
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
"""Swap the active data mixer, optionally resetting the iterator."""
self.data_mixer = data_mixer
if reset:
self.reset_data_iterator()
def training_step(self) -> TrainingStats:
"""Run one training step (algorithm-agnostic)."""
if self._iterator is None:
self._iterator = self._build_data_iterator()
return self.algorithm.update(self._iterator)
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
"""Apply policy preprocessing to RL observations only."""
observations = batch["state"]
next_observations = batch["next_state"]
batch["state"] = preprocessor.process_observation(observations)
batch["next_state"] = preprocessor.process_observation(next_observations)
return batch
class _PreprocessedIterator:
"""Iterator wrapper that preprocesses each sampled RL batch."""
__slots__ = ("_raw", "_preprocessor")
def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None:
self._raw = raw_iterator
self._preprocessor = preprocessor
def __iter__(self) -> _PreprocessedIterator:
return self
def __next__(self) -> BatchType:
batch = next(self._raw)
return preprocess_rl_batch(self._preprocessor, batch)
+1 -1
View File
@@ -46,7 +46,7 @@ class LeKiwiConfig(RobotConfig):
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
# Set to `True` for backward compatibility with previous policies/dataset
use_degrees: bool = False
use_degrees: bool = True
@dataclass
@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
clip_min: The minimum allowed gripper joint position.
clip_max: The maximum allowed gripper joint position.
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
discrete_gripper: If True, interpret the input as a discrete class index
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
"""
speed_factor: float = 20.0
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
raise ValueError("Joints observation is require for computing robot kinematics")
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_vel = (gripper_vel - 1) * self.clip_max
# Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
# Negation accounts for SO100 sign (joint position increases on close).
# 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
gripper_vel = -(gripper_vel - 1) * self.clip_max
# Compute desired gripper position
delta = gripper_vel * float(self.speed_factor)
+3 -4
View File
@@ -23,7 +23,6 @@ from lerobot.utils.robot_utils import precise_sleep
from ..context import RolloutContext
from .core import RolloutStrategy, send_next_action
from .display import BaseDisplay
logger = logging.getLogger(__name__)
@@ -39,8 +38,6 @@ class BaseStrategy(RolloutStrategy):
"""Initialise the inference engine."""
self._init_engine(ctx)
logger.info("Base strategy ready")
self._display = BaseDisplay(duration=ctx.runtime.cfg.duration)
self._display.show_banner()
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous control loop until shutdown or duration expires."""
@@ -75,7 +72,9 @@ class BaseStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
self._warn_slow_loop(dt, control_interval, cfg.fps)
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
def teardown(self, ctx: RolloutContext) -> None:
"""Disconnect hardware and stop inference."""
-12
View File
@@ -33,7 +33,6 @@ from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext
from .display import RolloutStatusDisplay
logger = logging.getLogger(__name__)
@@ -52,17 +51,6 @@ class RolloutStrategy(abc.ABC):
self._interpolator: ActionInterpolator | None = None
self._warmup_flushed: bool = False
self._cached_obs_processed: dict | None = None
self._display: RolloutStatusDisplay | None = None
def _warn_slow_loop(self, dt: float, control_interval: float, fps: float) -> None:
"""Warn when the control loop runs slower than the target FPS."""
if dt > control_interval:
logger.warning(
"Control loop running slower (%.1f Hz) than target (%.0f Hz). "
"Possible causes: camera FPS not keeping up, slow policy inference, CPU starvation.",
1 / dt,
fps,
)
def _init_engine(self, ctx: RolloutContext) -> None:
"""Attach the inference engine and action interpolator, then start the backend.
+7 -30
View File
@@ -71,7 +71,6 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .display import DAggerDisplay
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
@@ -287,7 +286,7 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.debug(
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
@@ -371,28 +370,6 @@ class DAggerStrategy(RolloutStrategy):
self._episode_duration_s,
)
if self.config.input_device == "keyboard":
kb = self.config.keyboard
pause_key, correction_key, upload_key = (
kb.pause_resume.upper(),
kb.correction.upper(),
kb.upload.upper(),
)
else:
pb = self.config.pedal
pause_key, correction_key, upload_key = pb.pause_resume, pb.correction, pb.upload
self._display = DAggerDisplay(
record_autonomous=self.config.record_autonomous,
num_episodes=self.config.num_episodes,
episode_duration_s=self._episode_duration_s,
input_device=self.config.input_device,
pause_key=pause_key,
correction_key=correction_key,
upload_key=upload_key,
)
self._display.show_banner()
def run(self, ctx: RolloutContext) -> None:
"""Run DAgger episodes with human-in-the-loop intervention."""
if self.config.record_autonomous:
@@ -465,7 +442,6 @@ class DAggerStrategy(RolloutStrategy):
interpolator.reset()
events.reset()
engine.resume()
self._display.show_state(DAggerPhase.AUTONOMOUS)
last_action: dict[str, Any] | None = None
record_tick = 0
@@ -496,7 +472,6 @@ class DAggerStrategy(RolloutStrategy):
ctx,
last_action,
)
self._display.show_state(new_phase)
if new_phase == DAggerPhase.AUTONOMOUS:
last_action = None
@@ -581,7 +556,9 @@ class DAggerStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
self._warn_slow_loop(dt, control_interval, cfg.fps)
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("DAgger continuous control loop ended — pausing engine")
@@ -622,7 +599,6 @@ class DAggerStrategy(RolloutStrategy):
interpolator.reset()
events.reset()
engine.resume()
self._display.show_state(DAggerPhase.AUTONOMOUS)
last_action: dict[str, Any] | None = None
start_time = time.perf_counter()
@@ -657,7 +633,6 @@ class DAggerStrategy(RolloutStrategy):
ctx,
last_action,
)
self._display.show_state(new_phase)
if new_phase == DAggerPhase.AUTONOMOUS:
last_action = None
@@ -730,7 +705,9 @@ class DAggerStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
self._warn_slow_loop(dt, control_interval, cfg.fps)
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("DAgger corrections-only loop ended — pausing engine")
-263
View File
@@ -1,263 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Console status display for rollout strategies.
One subclass per strategy static states/controls are declared as class
constants; runtime-dependent values are passed to ``__init__``.
In each strategy's ``setup()``:
self._display = DAggerDisplay(
record_autonomous=self.config.record_autonomous,
num_episodes=self.config.num_episodes,
episode_duration_s=self._episode_duration_s,
input_device=self.config.input_device,
pause_key="SPACE",
correction_key="TAB",
upload_key="ENTER",
)
self._display.show_banner()
On each state transition:
self._display.show_state("correcting")
"""
from __future__ import annotations
import enum
import sys
from dataclasses import dataclass
def _supports_color() -> bool:
return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
class _C:
"""ANSI escape codes."""
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
GREEN = "\033[1;92m"
YELLOW = "\033[1;93m"
RED = "\033[1;91m"
CYAN = "\033[1;96m"
WHITE = "\033[1;97m"
GRAY = "\033[2;37m"
@dataclass
class StateConfig:
"""One named rollout state.
``key`` must match the string passed to ``RolloutStatusDisplay.show_state()``.
"""
key: str
emoji: str
label: str
description: str
color: str = _C.WHITE
@dataclass
class ControlConfig:
"""One keyboard/pedal binding shown in the startup banner."""
key: str
description: str
# ---------------------------------------------------------------------------
# Base display class
# ---------------------------------------------------------------------------
class RolloutStatusDisplay:
"""Unified console status display. Subclass once per strategy."""
def __init__(
self,
strategy: str,
states: list[StateConfig],
controls: list[ControlConfig],
info: list[str] | None = None,
) -> None:
self.strategy = strategy
self._states = {s.key: s for s in states}
self._controls = controls
self._info = info or []
self._use_color = _supports_color()
def _c(self, code: str, text: str) -> str:
if not self._use_color:
return text
return f"{code}{text}{_C.RESET}"
def show_banner(self) -> None:
"""Print startup banner: strategy name, states, controls, config info."""
width = 62
sep = self._c(_C.BOLD, "" * width)
print(f"\n{sep}")
print(self._c(_C.BOLD, f" lerobot-rollout │ {self.strategy}"))
if self._states:
print()
for state in self._states.values():
label = self._c(state.color, f"{state.label:<14}")
desc = self._c(_C.GRAY, state.description)
print(f" {state.emoji} {label} {desc}")
if self._controls:
print()
key_width = max(len(c.key) for c in self._controls)
for ctrl in self._controls:
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
print(f" {key_str} {ctrl.description}")
if self._info:
print()
for item in self._info:
print(f" {item}")
print(f"{sep}\n")
def show_state(self, state_key: str | enum.Enum) -> None:
"""Print the current state and available controls - call this on every transition."""
key = state_key.value if isinstance(state_key, enum.Enum) else state_key
state = self._states.get(key)
if state is None:
return
label = self._c(state.color, f"{state.label:<14}")
desc = self._c(_C.GRAY, state.description)
print(f"\n {state.emoji} {label} {desc}\n")
if self._controls:
key_width = max(len(c.key) for c in self._controls)
for ctrl in self._controls:
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
print(f" {key_str} {ctrl.description}")
print()
# ---------------------------------------------------------------------------
# One display subclass per strategy
# ---------------------------------------------------------------------------
class BaseDisplay(RolloutStatusDisplay):
"""Status display for the base (eval-only, no recording) strategy."""
_STATES = [StateConfig("running", "🟢", "RUNNING", "autonomous rollout — no recording", _C.GREEN)]
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
def __init__(self, duration: float = 0) -> None:
info = ["No recording — evaluation only."]
if duration > 0:
info.append(f"Duration: {duration:.0f}s")
super().__init__("base", self._STATES, self._CONTROLS, info)
class SentryDisplay(RolloutStatusDisplay):
"""Status display for the sentry (continuous autonomous recording) strategy."""
_STATES = [StateConfig("recording", "🟢", "RECORDING", "continuous autonomous recording", _C.GREEN)]
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
def __init__(self, episode_duration_s: float, upload_every_n_episodes: int) -> None:
info = [
f"Episode rotation: ~{episode_duration_s:.0f}s | "
f"Upload every {upload_every_n_episodes} episodes",
]
super().__init__("sentry", self._STATES, self._CONTROLS, info)
class HighlightDisplay(RolloutStatusDisplay):
"""Status display for the highlight (ring-buffer on-demand save) strategy."""
def __init__(self, ring_buffer_seconds: float, save_key: str, push_key: str) -> None:
states = [
StateConfig(
"buffering",
"",
"BUFFERING",
f"ring buffer active — last {ring_buffer_seconds:.0f}s captured",
_C.WHITE,
),
StateConfig("recording", "🔴", "RECORDING", "live recording — press [s] to save episode", _C.RED),
]
controls = [
ControlConfig(save_key, "BUFFERING ↔ RECORDING start recording / save episode"),
ControlConfig(push_key, "push dataset to Hub (background)"),
ControlConfig("ESC", "stop session"),
]
super().__init__("highlight", states, controls)
class DAggerDisplay(RolloutStatusDisplay):
"""Status display for the dagger (human-in-the-loop) strategy."""
_PAUSED_STATE = StateConfig("paused", "🟡", "PAUSED", "holding last position — awaiting input", _C.YELLOW)
_CORRECTING_STATE = StateConfig(
"correcting", "🔴", "CORRECTING", "human teleop active — recording correction", _C.RED
)
def __init__(
self,
record_autonomous: bool,
num_episodes: int,
episode_duration_s: float,
input_device: str,
pause_key: str,
correction_key: str,
upload_key: str,
) -> None:
mode = "continuous recording" if record_autonomous else "corrections only"
auto_desc = "policy running — recording" if record_autonomous else "policy running — no recording"
states = [
StateConfig("autonomous", "🟢", "AUTONOMOUS", auto_desc, _C.GREEN),
self._PAUSED_STATE,
self._CORRECTING_STATE,
]
controls = [
ControlConfig(pause_key, "AUTONOMOUS ↔ PAUSED pause / resume policy"),
ControlConfig(correction_key, "PAUSED ↔ CORRECTING start / stop correction"),
ControlConfig(upload_key, "push dataset to Hub"),
ControlConfig("ESC", "stop session"),
]
info = [f"Target: {num_episodes} episodes | Input: {input_device}"]
if record_autonomous:
info.append(f"Episode rotation: ~{episode_duration_s:.0f}s")
super().__init__(f"dagger [{mode}]", states, controls, info)
if __name__ == "__main__":
dagger_display = DAggerDisplay(
record_autonomous=False,
num_episodes=20,
episode_duration_s=30,
input_device="keyboard",
pause_key="SPACE",
correction_key="TAB",
upload_key="ENTER",
)
dagger_display.show_banner()
dagger_display.show_state("paused")
dagger_display.show_state("correcting")
dagger_display.show_state("paused")
dagger_display.show_state("autonomous")
+4 -20
View File
@@ -17,7 +17,6 @@
from __future__ import annotations
import contextlib
import enum
import logging
import os
import sys
@@ -37,7 +36,6 @@ from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
from .display import HighlightDisplay
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
@@ -55,13 +53,6 @@ if PYNPUT_AVAILABLE:
logger = logging.getLogger(__name__)
class HighlightPhase(enum.Enum):
"""Observable phases of a Highlight session."""
BUFFERING = "buffering" # Ring buffer accumulating frames, not recording
RECORDING = "recording" # Live recording active
class HighlightStrategy(RolloutStrategy):
"""Autonomous rollout with on-demand recording via ring buffer.
@@ -114,13 +105,6 @@ class HighlightStrategy(RolloutStrategy):
self.config.save_key,
self.config.push_key,
)
self._display = HighlightDisplay(
ring_buffer_seconds=self.config.ring_buffer_seconds,
save_key=self.config.save_key,
push_key=self.config.push_key,
)
self._display.show_banner()
self._display.show_state(HighlightPhase.BUFFERING)
def run(self, ctx: RolloutContext) -> None:
"""Run the autonomous loop, buffering frames and recording on demand."""
@@ -178,7 +162,6 @@ class HighlightStrategy(RolloutStrategy):
for buffered_frame in ring.drain():
dataset.add_frame(buffered_frame)
self._recording_live.set()
self._display.show_state(HighlightPhase.RECORDING)
else:
dataset.add_frame(frame)
with self._episode_lock:
@@ -189,7 +172,6 @@ class HighlightStrategy(RolloutStrategy):
play_sounds,
)
self._recording_live.clear()
self._display.show_state(HighlightPhase.BUFFERING)
continue # frame already consumed — skip ring.append
if self._push_requested.is_set():
@@ -206,7 +188,9 @@ class HighlightStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
self._warn_slow_loop(dt, control_interval, cfg.fps)
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("Highlight control loop ended")
@@ -271,7 +255,7 @@ class HighlightStrategy(RolloutStrategy):
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
logger.debug("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
+3 -7
View File
@@ -32,7 +32,6 @@ from lerobot.utils.utils import log_say
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .display import SentryDisplay
logger = logging.getLogger(__name__)
@@ -80,11 +79,6 @@ class SentryStrategy(RolloutStrategy):
self._episode_duration_s,
self.config.upload_every_n_episodes,
)
self._display = SentryDisplay(
episode_duration_s=self._episode_duration_s,
upload_every_n_episodes=self.config.upload_every_n_episodes,
)
self._display.show_banner()
def run(self, ctx: RolloutContext) -> None:
"""Run the continuous recording loop with automatic episode rotation."""
@@ -166,7 +160,9 @@ class SentryStrategy(RolloutStrategy):
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
self._warn_slow_loop(dt, control_interval, cfg.fps)
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("Sentry control loop ended — saving final episode")
+1
View File
@@ -92,6 +92,7 @@ def get_sys_info() -> dict[str, str]:
info.update(
{
"PyTorch version": torch_version,
"Torchcodec version": get_package_version("torchcodec"),
"Is PyTorch built with CUDA support?": str(torch_cuda_available),
"Cuda version": cuda_version,
"GPU model": gpu_model,
@@ -104,11 +104,14 @@ class KeyboardTeleop(Teleoperator):
def _on_press(self, key):
if hasattr(key, "char"):
self.event_queue.put((key.char, True))
key = key.char
self.event_queue.put((key, True))
def _on_release(self, key):
if hasattr(key, "char"):
self.event_queue.put((key.char, False))
key = key.char
self.event_queue.put((key, False))
if key == keyboard.Key.esc:
logging.info("ESC pressed, disconnecting.")
self.disconnect()
@@ -204,8 +207,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
# this is useful for retrieving other events like interventions for RL, episode success, etc.
self.misc_keys_queue.put(key)
self.current_pressed.clear()
action_dict = {
"delta_x": delta_x,
"delta_y": delta_y,
@@ -256,6 +257,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
]
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
self.current_pressed.clear()
# Check for episode control commands from misc_keys_queue
terminate_episode = False
success = False
@@ -39,8 +39,8 @@ For more details, see the [Physical Intelligence π₀ blog post](https://www.ph
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
{% elif model_name == "sac" %}
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
{% elif model_name == "gaussian_actor" %}
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
{% elif model_name == "reward_classifier" %}
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
{% else %}
+1
View File
@@ -40,6 +40,7 @@ PolicyAction = torch.Tensor
RobotAction = dict[str, Any]
EnvAction = np.ndarray
RobotObservation = dict[str, Any]
BatchType = dict[str, Any]
EnvTransition = TypedDict(
+1
View File
@@ -47,6 +47,7 @@ CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"
ALGORITHM_DIR = "algorithm"
RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
+1
View File
@@ -132,6 +132,7 @@ _faker_available = is_package_available("faker")
_pynput_available = is_package_available("pynput")
_pygame_available = is_package_available("pygame")
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
_grpc_available = is_package_available("grpcio", import_name="grpc")
_wallx_deps_available = (
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
)
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
oid sha256:51effd76b73e972f10d31f5084ab906386134b600c87b2668767d30232a902bd
size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
size 47424
oid sha256:d4d7a16ca67f9adefac0e0620a7b2e9c822f2db42faaaced7a89fbad60e5ead4
size 47680
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
oid sha256:796c439ee8a64bf9901ff8325e7419bda8bd316360ee95e6304e8e1ae0f4c36c
size 68
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
size 47424
oid sha256:ad33a8b47c39c2e1374567ff9da43cdb95e2dbe904c1b02a35051346d3043095
size 47680
+65
View File
@@ -1691,3 +1691,68 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
threshold = sorted(lengths)[len(lengths) // 2]
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
expected_frames = sum(lengths[i] for i in expected_eps)
filtered = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["length"] >= threshold,
)
assert filtered.num_episodes == len(expected_eps)
assert filtered.num_frames == expected_frames
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
assert seen_eps == set(expected_eps)
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
lengths = dataset.meta.episodes["length"]
candidates = [0, 2, 4, 6]
candidate_lengths = [lengths[i] for i in candidates]
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
expected_eps = [i for i in candidates if lengths[i] >= threshold]
filtered = LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episodes=candidates,
episode_filter=lambda ep: ep["length"] >= threshold,
)
assert filtered.num_episodes == len(expected_eps)
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
assert seen_eps == set(expected_eps)
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
with pytest.raises(ValueError, match=r"The episode filter did not match any episode"):
LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["length"] < 0,
)
def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory):
"""A predicate referencing a column absent from meta.episodes surfaces a clear KeyError."""
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
with pytest.raises(KeyError, match="not_a_real_field"):
LeRobotDataset(
dataset.repo_id,
root=dataset.root,
episode_filter=lambda ep: ep["not_a_real_field"] > 0,
)
@@ -17,19 +17,19 @@
import pytest
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import (
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
GaussianActorConfig,
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
config = SACConfig()
def test_gaussian_actor_config_default_initialization():
config = GaussianActorConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
@@ -55,9 +55,6 @@ def test_sac_config_default_initialization():
# Basic parameters
assert config.device == "cpu"
assert config.storage_device == "cpu"
assert config.discount == 0.99
assert config.temperature_init == 1.0
assert config.num_critics == 2
# Architecture specifics
assert config.vision_encoder_name is None
@@ -66,6 +63,8 @@ def test_sac_config_default_initialization():
assert config.shared_encoder is True
assert config.num_discrete_actions is None
assert config.image_embedding_pooling_dim == 8
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
# Training parameters
assert config.online_steps == 1000000
@@ -73,20 +72,6 @@ def test_sac_config_default_initialization():
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False
assert config.online_step_before_learning == 100
assert config.policy_update_freq == 1
# SAC algorithm parameters
assert config.num_subsample_critics is None
assert config.critic_lr == 3e-4
assert config.actor_lr == 3e-4
assert config.temperature_lr == 3e-4
assert config.critic_target_update_weight == 0.005
assert config.utd_ratio == 1
assert config.state_encoder_hidden_dim == 256
assert config.latent_dim == 256
assert config.target_entropy is None
assert config.use_backup_entropy is True
assert config.grad_clip_norm == 40.0
# Dataset stats defaults
expected_dataset_stats = {
@@ -105,11 +90,6 @@ def test_sac_config_default_initialization():
}
assert config.dataset_stats == expected_dataset_stats
# Critic network configuration
assert config.critic_network_kwargs.hidden_dims == [256, 256]
assert config.critic_network_kwargs.activate_final is True
assert config.critic_network_kwargs.final_activation is None
# Actor network configuration
assert config.actor_network_kwargs.hidden_dims == [256, 256]
assert config.actor_network_kwargs.activate_final is True
@@ -135,7 +115,6 @@ def test_sac_config_default_initialization():
assert config.concurrency.learner == "threads"
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
assert isinstance(config.policy_kwargs, PolicyConfig)
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
assert isinstance(config.concurrency, ConcurrencyConfig)
@@ -175,22 +154,22 @@ def test_concurrency_config():
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
num_critics=3,
latent_dim=128,
state_encoder_hidden_dim=128,
num_discrete_actions=3,
)
assert config.device == "cpu"
assert config.discount == 0.95
assert config.temperature_init == 0.5
assert config.num_critics == 3
assert config.latent_dim == 128
assert config.state_encoder_hidden_dim == 128
assert config.num_discrete_actions == 3
def test_validate_features():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -198,7 +177,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
config = GaussianActorConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -209,7 +188,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -0,0 +1,528 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from torch import Tensor, nn # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 42
set_seed(seed)
def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(10)
y = mlp(x)
assert y.shape == (256,)
def test_mlp_with_batch_dim():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(2, 10)
y = mlp(x)
assert y.shape == (2, 256)
def test_forward_with_empty_hidden_dims():
mlp = MLP(input_dim=10, hidden_dims=[])
x = torch.randn(1, 10)
assert mlp(x).shape == (1, 10)
def test_mlp_with_dropout():
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 11)
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
assert drop_out_layers_count == 2
def test_mlp_with_custom_final_activation():
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 256)
assert (y >= -1).all() and (y <= 1).all()
def test_gaussian_actor_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
GaussianActorPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return {
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
return torch.randn(batch_size, action_dim)
def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> GaussianActorConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
ACTION: {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
},
)
config.validate_features()
return config
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> GaussianActorConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
config.validate_features()
return config
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
"""Helper to create policy + algorithm pair for tests that need critics."""
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
return algorithm, policy
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
# squeeze(0) removes batch dim when batch_size==1
assert selected_action.shape[-1] == action_dim
def test_gaussian_actor_policy_select_action_with_discrete():
"""select_action should return continuous + discrete actions."""
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.num_discrete_actions = 3
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
# Squeeze to unbatched (single observation)
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
selected_action = policy.select_action(observation_batch)
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.eval()
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
with torch.no_grad():
output = policy.forward(batch)
assert "action" in output
assert "log_prob" in output
assert "action_mean" in output
assert output["action"].shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
temp_loss = algorithm._compute_loss_temperature(forward_batch)
assert temp_loss.item() is not None
assert temp_loss.shape == ()
algorithm.optimizers["temperature"].zero_grad()
temp_loss.backward()
algorithm.optimizers["temperature"].step()
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape[-1] == action_dim
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_gaussian_actor_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_gaussian_actor_training_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
def test_gaussian_actor_training_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
config.num_discrete_actions = 5
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
assert discrete_critic_loss.shape == ()
algorithm.optimizers["discrete_critic"].zero_grad()
discrete_critic_loss.backward()
algorithm.optimizers["discrete_critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
# Policy.select_action now handles both continuous + discrete
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
assert selected_action.shape[-1] == continuous_action_dim + 1
def test_sac_algorithm_target_entropy():
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
config = create_default_config(continuous_action_dim=10, state_dim=10)
algorithm, _ = _make_algorithm(config)
assert algorithm.target_entropy == -5.0
def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algorithm, _ = _make_algorithm(config)
assert algorithm.target_entropy == -3.5
def test_sac_algorithm_temperature():
import math
config = create_default_config(continuous_action_dim=10, state_dim=10)
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.temperature == pytest.approx(1.0)
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
assert algorithm.temperature == pytest.approx(0.1)
def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.critic_target_update_weight = 1.0
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
for p in algorithm.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
algorithm._update_target_networks()
for p in algorithm.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data))
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algo_config.num_critics = num_critics
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
assert len(algorithm.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_gaussian_actor_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
with torch.no_grad():
with seeded_context(12):
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
assert torch.allclose(actions, loaded_actions)
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
state_dim = 10
action_dim = 6
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_discrete_actions = 3
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert loaded_policy.discrete_critic is not None
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
-546
View File
@@ -1,546 +0,0 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
@pytest.fixture(autouse=True)
def set_random_seed():
seed = 42
set_seed(seed)
def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(10)
y = mlp(x)
assert y.shape == (256,)
def test_mlp_with_batch_dim():
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
x = torch.randn(2, 10)
y = mlp(x)
assert y.shape == (2, 256)
def test_forward_with_empty_hidden_dims():
mlp = MLP(input_dim=10, hidden_dims=[])
x = torch.randn(1, 10)
assert mlp(x).shape == (1, 10)
def test_mlp_with_dropout():
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 11)
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
assert drop_out_layers_count == 2
def test_mlp_with_custom_final_activation():
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
x = torch.randn(1, 10)
y = mlp(x)
assert y.shape == (1, 256)
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return {
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
return torch.randn(batch_size, action_dim)
def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
"done": torch.randn(batch_size),
}
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
"""Create optimizers for the SAC policy."""
optimizer_actor = torch.optim.Adam(
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(),
lr=policy.config.critic_lr,
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha],
lr=policy.config.critic_lr,
)
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if has_discrete_action:
optimizers["discrete_critic"] = torch.optim.Adam(
params=policy.discrete_critic.parameters(),
lr=policy.config.critic_lr,
)
return optimizers
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
ACTION: {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
},
)
config.validate_features()
return config
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}
# Let make tests a little bit faster
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
config.validate_features()
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
# Let's check best candidates for pretrained encoders
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_sac_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_sac_policy_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
def test_sac_policy_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1 # the last action is discrete
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
num_discrete_actions = 5
config.num_discrete_actions = num_discrete_actions
policy = SACPolicy(config=config)
policy.train()
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
policy.train()
optimizers = make_optimizers(policy, has_discrete_action=True)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
assert discrete_critic_loss.item() is not None
assert discrete_critic_loss.shape == ()
discrete_critic_loss.backward()
optimizers["discrete_critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, full_action_dim)
discrete_actions = selected_action[:, -1].long()
discrete_action_values = set(discrete_actions.tolist())
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
)
def test_sac_policy_with_default_entropy():
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.target_entropy == -5.0
def test_sac_policy_default_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
policy = SACPolicy(config=config)
assert policy.target_entropy == -3.0
def test_sac_policy_with_predefined_entropy():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.target_entropy = -3.5
policy = SACPolicy(config=config)
assert policy.target_entropy == pytest.approx(-3.5)
def test_sac_policy_update_temperature():
"""Test that temperature property is always in sync with log_alpha."""
config = create_default_config(continuous_action_dim=10, state_dim=10)
policy = SACPolicy(config=config)
assert policy.temperature == pytest.approx(1.0)
policy.log_alpha.data = torch.tensor([math.log(0.1)])
# Temperature property automatically reflects log_alpha changes
assert policy.temperature == pytest.approx(0.1)
def test_sac_policy_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
policy = SACPolicy(config=config)
policy.train()
for p in policy.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
policy.update_target_networks()
for p in policy.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data)), (
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
)
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
policy = SACPolicy(config=config)
policy.train()
assert len(policy.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
policy.train()
optimizers = make_optimizers(policy)
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert cirtic_loss.item() is not None
assert cirtic_loss.shape == ()
cirtic_loss.backward()
optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
root = tmp_path / "test_sac_save_and_load"
state_dim = 10
action_dim = 10
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
with torch.no_grad():
with seeded_context(12):
# Collect policy values before saving
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
# Collect policy values after loading
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
# Compare values before and after saving and loading
# They should be the same
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
assert torch.allclose(actor_loss, loaded_actor_loss)
assert torch.allclose(temperature_loss, loaded_temperature_loss)
assert torch.allclose(actions, loaded_actions)
@@ -21,8 +21,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
@@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE
def create_default_config():
"""Create a default SAC configuration for testing."""
config = SACConfig()
config = GaussianActorConfig()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
}
@@ -66,7 +66,7 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -88,12 +88,12 @@ def test_make_sac_processor_basic():
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
def test_sac_processor_normalization_modes():
def test_gaussian_actor_processor_normalization_modes():
"""Test that SAC processor correctly handles different normalization modes."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_cuda():
def test_gaussian_actor_processor_cuda():
"""Test SAC processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -153,13 +153,13 @@ def test_sac_processor_cuda():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_accelerate_scenario():
def test_gaussian_actor_processor_accelerate_scenario():
"""Test SAC processor in simulated Accelerate scenario."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_sac_processor_multi_gpu():
def test_gaussian_actor_processor_multi_gpu():
"""Test SAC processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu():
assert processed[TransitionKey.ACTION.value].device == device
def test_sac_processor_without_stats():
def test_gaussian_actor_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -226,12 +226,12 @@ def test_sac_processor_without_stats():
assert processed is not None
def test_sac_processor_save_and_load():
def test_gaussian_actor_processor_save_and_load():
"""Test saving and loading SAC processor."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -257,14 +257,14 @@ def test_sac_processor_save_and_load():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_mixed_precision():
def test_gaussian_actor_processor_mixed_precision():
"""Test SAC processor with mixed precision."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision():
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
def test_sac_processor_batch_data():
def test_gaussian_actor_processor_batch_data():
"""Test SAC processor with batched data."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -329,12 +329,12 @@ def test_sac_processor_batch_data():
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
def test_sac_processor_edge_cases():
def test_gaussian_actor_processor_edge_cases():
"""Test SAC processor with edge cases."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -358,13 +358,13 @@ def test_sac_processor_edge_cases():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_bfloat16_device_float32_normalizer():
def test_gaussian_actor_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_sac_pre_post_processors(
preprocessor, _ = make_gaussian_actor_pre_post_processors(
config,
stats,
)
+14 -7
View File
@@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict():
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
), f"Stats for {key}.{stat_name} should not match original stats"
# Verify that _tensor_stats are also correctly set to match the override stats
# Verify that _tensor_stats values match the override stats
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats
expected_tensor_stats = to_tensor(override_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
torch.testing.assert_close(
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
override_normalizer._tensor_stats[key][stat_name].squeeze(),
expected_tensor_stats[key][stat_name].squeeze(),
)
@@ -1849,12 +1851,16 @@ def test_stats_without_override_loads_normally():
# Stats should now match the original stats (normal behavior)
# Check that all keys and values match
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats,
# so we squeeze before comparing values.
for key in original_stats:
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
for stat_name in original_stats[key]:
np.testing.assert_allclose(
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
)
actual = new_normalizer.stats[key][stat_name]
expected = original_stats[key][stat_name]
if hasattr(actual, "squeeze"):
actual = actual.squeeze()
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)
def test_stats_explicit_provided_flag_detection():
@@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict():
assert ACTION in new_normalizer.stats
# Check that values are correct (converted back from tensors)
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2])
# Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5])
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2])
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])
-13
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@@ -36,9 +35,6 @@ def test_classifier_output():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -80,9 +76,6 @@ def test_binary_classifier_with_default_params():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -122,9 +115,6 @@ def test_multiclass_classifier():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.rewards.classifier.modeling_classifier import Classifier
@@ -141,9 +131,6 @@ def test_default_device():
@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.rewards.classifier.modeling_classifier import Classifier
+167 -4
View File
@@ -22,12 +22,14 @@ import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("grpc")
from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import OBS_STR
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing
@@ -79,7 +81,7 @@ def cfg():
port = find_free_port()
policy_cfg = SACConfig()
policy_cfg = GaussianActorConfig()
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
policy_cfg.actor_learner_config.learner_port = port
policy_cfg.concurrency.actor = "threads"
@@ -299,3 +301,164 @@ def test_end_to_end_parameters_flow(cfg, data_size):
assert received_params.keys() == input_params.keys()
for key in input_params:
assert torch.allclose(received_params[key], input_params[key])
def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
get_weights() output is serializable."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.transport.utils import state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
)
sac_cfg.validate_features()
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
assert "critic" in optimizers
assert "temperature" in optimizers
batch_size = 4
def batch_iterator():
while True:
yield {
ACTION: torch.randn(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": {OBS_STATE: torch.randn(batch_size, state_dim)},
"next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
"done": torch.zeros(batch_size),
"complementary_info": {},
}
stats = algorithm.update(batch_iterator())
assert "loss_critic" in stats.losses
# get_weights -> state_to_bytes round-trip
weights = algorithm.get_weights()
assert len(weights) > 0
serialized = state_to_bytes(weights)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
# RLTrainer with DataMixer
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.rl.trainer import RLTrainer
replay_buffer = ReplayBuffer(
capacity=50,
device="cpu",
state_keys=[OBS_STATE],
storage_device="cpu",
use_drq=False,
)
for _ in range(50):
replay_buffer.add(
state={OBS_STATE: torch.randn(state_dim)},
action=torch.randn(action_dim),
reward=1.0,
next_state={OBS_STATE: torch.randn(state_dim)},
done=False,
truncated=False,
)
data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=batch_size,
)
trainer_stats = trainer.training_step()
assert "loss_critic" in trainer_stats.losses
def test_initial_and_periodic_weight_push_consistency():
"""Both initial and periodic weight pushes should use algorithm.get_weights()
and produce identical structures."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
)
sac_cfg.validate_features()
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
algorithm.make_optimizers_and_scheduler()
# Simulate initial push (same code path the learner now uses)
initial_weights = algorithm.get_weights()
initial_bytes = state_to_bytes(initial_weights)
# Simulate periodic push
periodic_weights = algorithm.get_weights()
periodic_bytes = state_to_bytes(periodic_weights)
initial_decoded = bytes_to_state_dict(initial_bytes)
periodic_decoded = bytes_to_state_dict(periodic_bytes)
assert initial_decoded.keys() == periodic_decoded.keys()
def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
state_dim = 10
action_dim = 6
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
)
sac_cfg.validate_features()
# Actor side: no optimizers
policy = GaussianActorPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
# select_action should work
obs = {OBS_STATE: torch.randn(state_dim)}
action = policy.select_action(obs)
assert action.shape == (action_dim,)
# Simulate receiving weights from learner
fake_weights = algorithm.get_weights()
algorithm.load_weights(fake_weights, device="cpu")
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.rl.buffer import ReplayBuffer # noqa: E402
from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402
from lerobot.utils.constants import OBS_STATE # noqa: E402
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
buf = ReplayBuffer(
capacity=capacity,
device="cpu",
state_keys=[OBS_STATE],
storage_device="cpu",
use_drq=False,
)
for i in range(capacity):
buf.add(
state={OBS_STATE: torch.randn(state_dim)},
action=torch.randn(2),
reward=1.0,
next_state={OBS_STATE: torch.randn(state_dim)},
done=bool(i % 10 == 9),
truncated=False,
)
return buf
def test_online_only_mixer_sample():
"""OnlineOfflineMixer with no offline buffer returns online-only batches."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5)
batch = mixer.sample(batch_size=8)
assert batch["state"][OBS_STATE].shape[0] == 8
assert batch["action"].shape[0] == 8
assert batch["reward"].shape[0] == 8
def test_online_only_mixer_ratio_one():
"""OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0)
batch = mixer.sample(batch_size=10)
assert batch["state"][OBS_STATE].shape[0] == 10
def test_online_offline_mixer_sample():
"""OnlineOfflineMixer with two buffers returns concatenated batches."""
online = _make_buffer(capacity=50)
offline = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(
online_buffer=online,
offline_buffer=offline,
online_ratio=0.5,
)
batch = mixer.sample(batch_size=10)
assert batch["state"][OBS_STATE].shape[0] == 10
assert batch["action"].shape[0] == 10
# 5 from online, 5 from offline (approx)
assert batch["reward"].shape[0] == 10
def test_online_offline_mixer_iterator():
"""get_iterator yields batches of the requested size."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None)
it = mixer.get_iterator(batch_size=4, async_prefetch=False)
batch1 = next(it)
batch2 = next(it)
assert batch1["state"][OBS_STATE].shape[0] == 4
assert batch2["state"][OBS_STATE].shape[0] == 4
+1 -1
View File
@@ -20,7 +20,7 @@ from queue import Queue
import pytest
pytest.importorskip("grpc")
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
+606
View File
@@ -0,0 +1,606 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy # noqa: E402
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats # noqa: E402
from lerobot.rl.algorithms.factory import make_algorithm # noqa: E402
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
# ---------------------------------------------------------------------------
# Helpers (reuse patterns from tests/policies/test_gaussian_actor_policy.py)
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def set_random_seed():
set_seed(42)
def _make_sac_config(
state_dim: int = 10,
action_dim: int = 6,
num_discrete_actions: int | None = None,
with_images: bool = False,
) -> GaussianActorConfig:
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
num_discrete_actions=num_discrete_actions,
)
if with_images:
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1).tolist(),
"std": torch.randn(3, 1, 1).abs().tolist(),
}
config.latent_dim = 32
config.state_encoder_hidden_dim = 32
config.validate_features()
return config
def _make_algorithm(
state_dim: int = 10,
action_dim: int = 6,
utd_ratio: int = 1,
policy_update_freq: int = 1,
num_discrete_actions: int | None = None,
with_images: bool = False,
) -> tuple[SACAlgorithm, GaussianActorPolicy]:
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = utd_ratio
algo_config.policy_update_freq = policy_update_freq
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers_and_scheduler()
return algorithm, policy
def _make_batch(
batch_size: int = 4,
state_dim: int = 10,
action_dim: int = 6,
with_images: bool = False,
) -> dict:
obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
if with_images:
obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
return {
ACTION: torch.randn(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": obs,
"next_state": next_obs,
"done": torch.zeros(batch_size),
"complementary_info": {},
}
def _batch_iterator(**batch_kwargs):
"""Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator)."""
while True:
yield _make_batch(**batch_kwargs)
# ===========================================================================
# Registry / config tests
# ===========================================================================
def test_sac_algorithm_config_registered():
"""SACAlgorithmConfig should be discoverable through the registry."""
assert "sac" in RLAlgorithmConfig.get_known_choices()
cls = RLAlgorithmConfig.get_choice_class("sac")
assert cls is SACAlgorithmConfig
def test_sac_algorithm_config_from_policy_config():
"""from_policy_config embeds the policy config and uses SAC defaults."""
sac_cfg = _make_sac_config()
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
assert algo_cfg.policy_config is sac_cfg
assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs
# Defaults come from SACAlgorithmConfig, not from the policy config.
assert algo_cfg.utd_ratio == 1
assert algo_cfg.policy_update_freq == 1
assert algo_cfg.grad_clip_norm == 40.0
assert algo_cfg.actor_lr == 3e-4
# ===========================================================================
# TrainingStats tests
# ===========================================================================
def test_training_stats_defaults():
stats = TrainingStats()
assert stats.losses == {}
assert stats.grad_norms == {}
assert stats.extra == {}
# ===========================================================================
# get_weights
# ===========================================================================
def test_get_weights_returns_policy_state_dict():
algorithm, policy = _make_algorithm()
weights = algorithm.get_weights()
assert "policy" in weights
actor_state_dict = policy.actor.state_dict()
for key in actor_state_dict:
assert key in weights["policy"]
assert torch.equal(weights["policy"][key].cpu(), actor_state_dict[key].cpu())
def test_get_weights_includes_discrete_critic_when_present():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
weights = algorithm.get_weights()
assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0
def test_get_weights_excludes_discrete_critic_when_absent():
algorithm, _ = _make_algorithm()
weights = algorithm.get_weights()
assert "discrete_critic" not in weights
def test_get_weights_are_on_cpu():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
weights = algorithm.get_weights()
for group_name, state_dict in weights.items():
for key, tensor in state_dict.items():
assert tensor.device == torch.device("cpu"), f"{group_name}/{key} is not on CPU"
# ===========================================================================
# select_action (lives on the policy, not the algorithm)
# ===========================================================================
def test_select_action_returns_correct_shape():
action_dim = 6
_, policy = _make_algorithm(state_dim=10, action_dim=action_dim)
policy.eval()
obs = {OBS_STATE: torch.randn(10)}
action = policy.select_action(obs)
assert action.shape == (action_dim,)
def test_select_action_with_discrete_critic():
continuous_dim = 5
_, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3)
policy.eval()
obs = {OBS_STATE: torch.randn(10)}
action = policy.select_action(obs)
assert action.shape == (continuous_dim + 1,)
# ===========================================================================
# update (single batch, utd_ratio=1)
# ===========================================================================
def test_update_returns_training_stats():
algorithm, _ = _make_algorithm()
stats = algorithm.update(_batch_iterator())
assert isinstance(stats, TrainingStats)
assert "loss_critic" in stats.losses
assert isinstance(stats.losses["loss_critic"], float)
def test_update_populates_actor_and_temperature_losses():
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
algorithm, _ = _make_algorithm(policy_update_freq=1)
stats = algorithm.update(_batch_iterator())
assert "loss_actor" in stats.losses
assert "loss_temperature" in stats.losses
assert "temperature" in stats.extra
@pytest.mark.parametrize("policy_update_freq", [2, 3])
def test_update_skips_actor_at_non_update_steps(policy_update_freq):
"""Actor/temperature should only update when optimization_step % freq == 0."""
algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq)
it = _batch_iterator()
# Step 0: should update actor
stats_0 = algorithm.update(it)
assert "loss_actor" in stats_0.losses
# Step 1: should NOT update actor
stats_1 = algorithm.update(it)
assert "loss_actor" not in stats_1.losses
def test_update_increments_optimization_step():
algorithm, _ = _make_algorithm()
it = _batch_iterator()
assert algorithm.optimization_step == 0
algorithm.update(it)
assert algorithm.optimization_step == 1
algorithm.update(it)
assert algorithm.optimization_step == 2
def test_update_with_discrete_critic():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
assert "loss_discrete_critic" in stats.losses
assert "discrete_critic" in stats.grad_norms
# ===========================================================================
# update with UTD ratio > 1
# ===========================================================================
@pytest.mark.parametrize("utd_ratio", [2, 4])
def test_update_with_utd_ratio(utd_ratio):
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
stats = algorithm.update(_batch_iterator())
assert isinstance(stats, TrainingStats)
assert "loss_critic" in stats.losses
assert algorithm.optimization_step == 1
def test_update_utd_ratio_pulls_utd_batches():
"""next(batch_iterator) should be called exactly utd_ratio times."""
utd_ratio = 3
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
call_count = 0
def counting_iterator():
nonlocal call_count
while True:
call_count += 1
yield _make_batch()
algorithm.update(counting_iterator())
assert call_count == utd_ratio
def test_update_utd_ratio_3_critic_warmup_changes_weights():
"""With utd_ratio=3, critic weights should change after update (3 critic steps)."""
algorithm, policy = _make_algorithm(utd_ratio=3)
critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()}
algorithm.update(_batch_iterator())
changed = False
for n, p in algorithm.critic_ensemble.named_parameters():
if not torch.equal(p, critic_params_before[n]):
changed = True
break
assert changed, "Critic weights should have changed after UTD update"
# ===========================================================================
# get_observation_features
# ===========================================================================
def test_get_observation_features_returns_none_without_frozen_encoder():
algorithm, _ = _make_algorithm(with_images=False)
obs = {OBS_STATE: torch.randn(4, 10)}
next_obs = {OBS_STATE: torch.randn(4, 10)}
feat, next_feat = algorithm.get_observation_features(obs, next_obs)
assert feat is None
assert next_feat is None
# ===========================================================================
# optimization_step setter
# ===========================================================================
def test_optimization_step_can_be_set_for_resume():
algorithm, _ = _make_algorithm()
algorithm.optimization_step = 100
assert algorithm.optimization_step == 100
# ===========================================================================
# make_algorithm factory
# ===========================================================================
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
assert "critic" in optimizers
assert "temperature" in optimizers
assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values())
assert algorithm.get_optimizers() is optimizers
def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
def test_make_algorithm_uses_sac_algorithm_defaults():
"""make_algorithm populates SACAlgorithmConfig with its own defaults."""
sac_cfg = _make_sac_config()
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
assert algorithm.config.utd_ratio == 1
assert algorithm.config.policy_update_freq == 1
assert algorithm.config.grad_clip_norm == 40.0
def test_unknown_algorithm_name_raises_in_registry():
"""The ChoiceRegistry is the source of truth for unknown algorithm names."""
with pytest.raises(KeyError):
RLAlgorithmConfig.get_choice_class("unknown_algo")
# ===========================================================================
# load_weights (round-trip with get_weights)
# ===========================================================================
def test_load_weights_round_trip():
"""get_weights -> load_weights should restore identical parameters on a fresh policy."""
algo_src, _ = _make_algorithm(state_dim=10, action_dim=6)
algo_src.update(_batch_iterator())
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
algo_dst.load_weights(weights, device="cpu")
dst_actor_state_dict = algo_dst.policy.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(
dst_actor_state_dict[key].cpu(),
tensor.cpu(),
), f"Policy param '{key}' mismatch after load_weights"
def test_load_weights_round_trip_with_discrete_critic():
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
algo_dst.load_weights(weights, device="cpu")
assert "discrete_critic" in weights
assert len(weights["discrete_critic"]) > 0
dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(
dst_discrete_critic_state_dict[key].cpu(),
tensor.cpu(),
), f"Discrete critic param '{key}' mismatch after load_weights"
def test_load_weights_ignores_missing_discrete_critic():
"""load_weights should not fail when weights lack discrete_critic on a non-discrete policy."""
algorithm, _ = _make_algorithm()
weights = algorithm.get_weights()
algorithm.load_weights(weights, device="cpu")
def test_actor_side_weight_sync_with_discrete_critic():
"""End-to-end: learner ``algorithm.get_weights()`` -> actor ``algorithm.load_weights()``."""
# Learner side: train the source algorithm so its weights diverge from init.
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
weights = algo_src.get_weights()
# Actor side: fresh policy + fresh algorithm holding it.
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_actor = GaussianActorPolicy(config=sac_cfg)
algo_actor = SACAlgorithm(
policy=policy_actor,
config=SACAlgorithmConfig.from_policy_config(sac_cfg),
)
# Snapshot initial actor state for the "did it change?" assertion below.
initial_discrete_critic_state_dict = {
k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items()
}
algo_actor.load_weights(weights, device="cpu")
# Actor weights match the learner's exported actor state dict.
actor_state_dict = policy_actor.actor.state_dict()
for key, tensor in weights["policy"].items():
assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), (
f"Actor param '{key}' not synced by algorithm.load_weights"
)
# Discrete critic weights match the learner's exported discrete critic.
discrete_critic_state_dict = policy_actor.discrete_critic.state_dict()
for key, tensor in weights["discrete_critic"].items():
assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), (
f"Discrete critic param '{key}' not synced by algorithm.load_weights"
)
# Sanity: the discrete critic actually changed (otherwise the sync is trivial).
changed = any(
not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key])
for key in initial_discrete_critic_state_dict
if key in discrete_critic_state_dict
)
assert changed, "Discrete critic weights did not change between init and after sync"
# ===========================================================================
# TrainingStats generic losses dict
# ===========================================================================
def test_training_stats_generic_losses():
stats = TrainingStats(
losses={"loss_bc": 0.5, "loss_q": 1.2},
extra={"temperature": 0.1},
)
assert stats.losses["loss_bc"] == 0.5
assert stats.losses["loss_q"] == 1.2
assert stats.extra["temperature"] == 0.1
# ===========================================================================
# Registry-driven make_algorithm
# ===========================================================================
def test_make_algorithm_builds_sac():
"""make_algorithm should look up the SAC class from the registry and instantiate it."""
sac_cfg = _make_sac_config()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algo_config.utd_ratio = 2
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(cfg=algo_config, policy=policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.config.utd_ratio == 2
# ===========================================================================
# state_dict / load_state_dict (algorithm-side resume)
# ===========================================================================
def test_state_dict_contains_algorithm_owned_tensors():
"""state_dict should pack critics, target networks, and log_alpha (no encoder bloat)."""
algorithm, _ = _make_algorithm()
sd = algorithm.state_dict()
assert "log_alpha" in sd
assert any(k.startswith("critic_ensemble.") for k in sd)
assert any(k.startswith("critic_target.") for k in sd)
# encoder weights live on the policy and must not be duplicated here.
assert not any(".encoder." in k for k in sd)
def test_state_dict_includes_discrete_critic_target_when_present():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
sd = algorithm.state_dict()
assert any(k.startswith("discrete_critic_target.") for k in sd)
def test_load_state_dict_round_trip_restores_critics_and_log_alpha():
"""state_dict -> load_state_dict on a fresh algorithm restores all bytes exactly."""
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
src_policy = GaussianActorPolicy(config=sac_cfg)
src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
src.make_optimizers_and_scheduler()
# Train a few steps so weights diverge from init (action_dim=7 = 6 continuous + 1 discrete).
src.update(_batch_iterator(action_dim=7))
src.update(_batch_iterator(action_dim=7))
dst_policy = GaussianActorPolicy(config=sac_cfg)
dst = SACAlgorithm(policy=dst_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
dst.make_optimizers_and_scheduler()
src_sd = src.state_dict()
dst.load_state_dict(src_sd)
dst_sd = dst.state_dict()
assert set(dst_sd) == set(src_sd)
for key in src_sd:
assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after round-trip"
def test_load_state_dict_preserves_log_alpha_parameter_identity():
"""The temperature optimizer holds a reference to log_alpha; identity must survive load."""
algorithm, _ = _make_algorithm()
log_alpha_id_before = id(algorithm.log_alpha)
optimizer_param_id = id(algorithm.optimizers["temperature"].param_groups[0]["params"][0])
assert log_alpha_id_before == optimizer_param_id
new_state = algorithm.state_dict()
new_state["log_alpha"] = torch.tensor([0.42])
algorithm.load_state_dict(new_state)
assert id(algorithm.log_alpha) == log_alpha_id_before
assert id(algorithm.optimizers["temperature"].param_groups[0]["params"][0]) == log_alpha_id_before
assert torch.allclose(algorithm.log_alpha.detach().cpu(), torch.tensor([0.42]))
def test_save_pretrained_round_trip_via_disk(tmp_path):
"""End-to-end: save_pretrained -> from_pretrained restores tensors and config."""
sac_cfg = _make_sac_config()
src_policy = GaussianActorPolicy(config=sac_cfg)
src = SACAlgorithm(policy=src_policy, config=SACAlgorithmConfig.from_policy_config(sac_cfg))
src.make_optimizers_and_scheduler()
src.update(_batch_iterator())
save_dir = tmp_path / "algorithm"
src.save_pretrained(save_dir)
assert (save_dir / "model.safetensors").is_file()
assert (save_dir / "config.json").is_file()
dst_policy = GaussianActorPolicy(config=sac_cfg)
dst = SACAlgorithm.from_pretrained(save_dir, policy=dst_policy)
src_sd = src.state_dict()
dst_sd = dst.state_dict()
assert set(src_sd) == set(dst_sd)
for key in src_sd:
assert torch.allclose(src_sd[key].cpu(), dst_sd[key].cpu()), f"{key} mismatch after disk round-trip"
+133
View File
@@ -0,0 +1,133 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
from torch import Tensor # noqa: E402
from lerobot.rl.algorithms.base import RLAlgorithm # noqa: E402
from lerobot.rl.algorithms.configs import TrainingStats # noqa: E402
from lerobot.rl.trainer import RLTrainer # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
class _DummyRLAlgorithmConfig:
"""Dummy config for testing."""
class _DummyRLAlgorithm(RLAlgorithm):
config_class = _DummyRLAlgorithmConfig
name = "dummy_rl_algorithm"
def __init__(self):
self.configure_calls = 0
self.update_calls = 0
def select_action(self, observation: dict[str, Tensor]) -> Tensor:
return torch.zeros(1)
def configure_data_iterator(
self,
data_mixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
):
self.configure_calls += 1
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
def make_optimizers_and_scheduler(self):
return {}
def update(self, batch_iterator):
self.update_calls += 1
_ = next(batch_iterator)
return TrainingStats(losses={"dummy": 1.0})
def load_weights(self, weights, device="cpu") -> None:
_ = (weights, device)
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state_dict, device="cpu") -> None:
_ = (state_dict, device)
class _SimpleMixer:
def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):
_ = (async_prefetch, queue_size)
while True:
yield {
"state": {OBS_STATE: torch.randn(batch_size, 3)},
ACTION: torch.randn(batch_size, 2),
"reward": torch.randn(batch_size),
"next_state": {OBS_STATE: torch.randn(batch_size, 3)},
"done": torch.zeros(batch_size),
"truncated": torch.zeros(batch_size),
"complementary_info": None,
}
def test_trainer_lazy_iterator_lifecycle_and_reset():
algo = _DummyRLAlgorithm()
mixer = _SimpleMixer()
trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4)
# First call builds iterator once.
trainer.training_step()
assert algo.configure_calls == 1
assert algo.update_calls == 1
# Second call reuses existing iterator.
trainer.training_step()
assert algo.configure_calls == 1
assert algo.update_calls == 2
# Explicit reset forces lazy rebuild on next step.
trainer.reset_data_iterator()
trainer.training_step()
assert algo.configure_calls == 2
assert algo.update_calls == 3
def test_trainer_set_data_mixer_resets_by_default():
algo = _DummyRLAlgorithm()
mixer_a = _SimpleMixer()
mixer_b = _SimpleMixer()
trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2)
trainer.training_step()
assert algo.configure_calls == 1
trainer.set_data_mixer(mixer_b, reset=True)
trainer.training_step()
assert algo.configure_calls == 2
def test_algorithm_optimization_step_contract_defaults():
algo = _DummyRLAlgorithm()
assert algo.optimization_step == 0
algo.optimization_step = 11
assert algo.optimization_step == 11
+1 -1
View File
@@ -22,7 +22,7 @@ from unittest.mock import patch
import pytest
pytest.importorskip("grpc")
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
-1
View File
@@ -19,7 +19,6 @@ from collections.abc import Callable
import pytest
pytest.importorskip("grpc")
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
import torch # noqa: E402
Generated
+916 -659
View File
File diff suppressed because it is too large Load Diff