mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 15:49:53 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9db9c35cb4 | |||
| fe96b28c74 | |||
| 2438df1307 | |||
| f218d5ab30 | |||
| 04125492e4 | |||
| e963e5a0c4 | |||
| 26ff40ddd7 | |||
| 6d269b28c8 | |||
| b607c8458e | |||
| 9e83510c99 | |||
| 1f7b03f5f2 | |||
| cb8edf17e6 | |||
| 5699f6cbf4 | |||
| 0e6114ac36 |
@@ -152,13 +152,14 @@ jobs:
|
||||
BASE_VERSION="${VERSION%%-*}"
|
||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||
uv pip install \
|
||||
--torch-backend cpu \
|
||||
--index-url https://test.pypi.org/simple/ \
|
||||
--extra-index-url https://pypi.org/simple \
|
||||
--index-strategy unsafe-best-match \
|
||||
"lerobot[all]==$BASE_VERSION"
|
||||
else
|
||||
echo "Installing release version $VERSION from PyPI..."
|
||||
uv pip install "lerobot[all]==$VERSION"
|
||||
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
|
||||
fi
|
||||
- name: Check lerobot version
|
||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||
|
||||
@@ -19,8 +19,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Runs at 02:00
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
# schedule:
|
||||
# - cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
|
||||
@@ -232,6 +232,8 @@ Match the policy to the user's **GPU memory** and **time budget**. Numbers below
|
||||
|
||||
All policies typically train for **5–10 epochs** (see §7).
|
||||
|
||||
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
|
||||
|
||||
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
|
||||
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
|
||||
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
|
||||
|
||||
@@ -109,7 +109,7 @@ lerobot-train \
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
|
||||
|
||||
## Inference & Evaluation
|
||||
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
# Video benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the optimal trade-off between:
|
||||
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies,
|
||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
||||
|
||||
How to encode videos?
|
||||
|
||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
||||
|
||||
## Variables
|
||||
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
### Encoding parameters
|
||||
|
||||
| parameter | values |
|
||||
| ----------- | ------------------------------------------------------------ |
|
||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
||||
|
||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
||||
|
||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
||||
|
||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
||||
|
||||
### Decoding parameters
|
||||
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
|
||||
- `pyav`
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
||||
|
||||
- `1_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
|
||||
## Metrics
|
||||
|
||||
**Data compression ratio (lower is better)**
|
||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Loading time ratio (lower is better)**
|
||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average Mean Square Error (lower is better)**
|
||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
||||
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
||||
|
||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
||||
|
||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
||||
|
||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
||||
These are then all concatenated to a single table ready for analysis.
|
||||
|
||||
## Caveats
|
||||
|
||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
|
||||
- `torchaudio`
|
||||
- `ffmpegio`
|
||||
- `decord`
|
||||
- `nvc`
|
||||
|
||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
||||
|
||||
## Install
|
||||
|
||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
||||
|
||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
||||
|
||||
## Adding a video decoder
|
||||
|
||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
||||
|
||||
```diff
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
+ elif backend == ["your_decoder"]:
|
||||
+ return your_decoder_function(
|
||||
+ video_path, timestamps, tolerance_s, backend
|
||||
+ )
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
For a quick run, you can try these parameters:
|
||||
|
||||
```bash
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
--crf 10 40 None \
|
||||
--timestamps-modes 1_frame 2_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 5 \
|
||||
--num-workers 5 \
|
||||
--save-frames 0
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### Reproduce
|
||||
|
||||
We ran the benchmark with the following parameters:
|
||||
|
||||
```bash
|
||||
# h264 and h265 encodings
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
|
||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
```
|
||||
|
||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
||||
|
||||
### Parameters selected for LeRobotDataset
|
||||
|
||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
||||
|
||||
- vcodec: `libsvtav1`
|
||||
- pix-fmt: `yuv420p`
|
||||
- g: `2`
|
||||
- crf: `30`
|
||||
|
||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
||||
|
||||
### Summary
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
@@ -1,488 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will benchmark different video encoding and decoding parameters.
|
||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import itertools
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.utils import TimerManager
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
("vcodec", "libx264"),
|
||||
("pix_fmt", "yuv444p"),
|
||||
("g", 2),
|
||||
("crf", None),
|
||||
# TODO(aliberts): Add fastdecode
|
||||
# ("fastdecode", 0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
||||
def parse_int_or_none(value) -> int | None:
|
||||
if value.lower() == "none":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
||||
|
||||
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
total_size = 0
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return torch.stack(frames)
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
case "1_frame":
|
||||
frame_indexes = [idx]
|
||||
case "2_frames":
|
||||
frame_indexes = [idx - 1, idx]
|
||||
case "2_frames_4_space":
|
||||
frame_indexes = [idx - 5, idx]
|
||||
case "6_frames":
|
||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
||||
case _:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
timestamps_mode: str,
|
||||
backend: str,
|
||||
ep_num_images: int,
|
||||
fps: int,
|
||||
num_samples: int = 50,
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int, lock: Lock):
|
||||
time_benchmark = TimerManager(log=False)
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
"psnr_values": [],
|
||||
"ssim_values": [],
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark, lock:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
||||
|
||||
return result
|
||||
|
||||
load_times_video_ms = []
|
||||
load_times_images_ms = []
|
||||
mse_values = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
# Use a single shared lock for all worker threads
|
||||
shared_lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
psnr_values.extend(result["psnr_values"])
|
||||
ssim_values.extend(result["ssim_values"])
|
||||
mse_values.extend(result["mse_values"])
|
||||
|
||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
||||
|
||||
return {
|
||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
||||
"avg_mse": float(np.mean(mse_values)),
|
||||
"avg_psnr": float(np.mean(psnr_values)),
|
||||
"avg_ssim": float(np.mean(ssim_values)),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_encoding_decoding(
|
||||
dataset: LeRobotDataset,
|
||||
video_path: Path,
|
||||
imgs_dir: Path,
|
||||
encoding_cfg: dict,
|
||||
decoding_cfg: dict,
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
overwrite: bool = False,
|
||||
seed: int = 1337,
|
||||
) -> list[dict]:
|
||||
fps = dataset.fps
|
||||
|
||||
if overwrite or not video_path.is_file():
|
||||
tqdm.write(f"encoding {video_path}")
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
||||
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
timestamps_mode,
|
||||
backend,
|
||||
ep_num_images,
|
||||
fps,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
benchmark_row.update(
|
||||
**{
|
||||
"repo_id": dataset.repo_id,
|
||||
"resolution": f"{width} x {height}",
|
||||
"num_pixels": num_pixels,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"images_size_bytes": images_size_bytes,
|
||||
"video_images_size_ratio": video_images_size_ratio,
|
||||
"timestamps_mode": timestamps_mode,
|
||||
"backend": backend,
|
||||
},
|
||||
**encoding_cfg,
|
||||
)
|
||||
benchmark_table.append(benchmark_row)
|
||||
|
||||
return benchmark_table
|
||||
|
||||
|
||||
def main(
|
||||
output_dir: Path,
|
||||
repo_ids: list[str],
|
||||
vcodec: list[str],
|
||||
pix_fmt: list[str],
|
||||
g: list[int],
|
||||
crf: list[int],
|
||||
# fastdecode: list[int],
|
||||
timestamps_modes: list[str],
|
||||
backends: list[str],
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
):
|
||||
check_datasets_formats(repo_ids)
|
||||
encoding_benchmarks = {
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
# "fastdecode": fastdecode,
|
||||
}
|
||||
decoding_benchmarks = {
|
||||
"timestamps_modes": timestamps_modes,
|
||||
"backends": backends,
|
||||
}
|
||||
headers = ["repo_id", "resolution", "num_pixels"]
|
||||
headers += list(BASE_ENCODING.keys())
|
||||
headers += [
|
||||
"timestamps_mode",
|
||||
"backend",
|
||||
"video_size_bytes",
|
||||
"images_size_bytes",
|
||||
"video_images_size_ratio",
|
||||
"avg_load_time_video_ms",
|
||||
"avg_load_time_images_ms",
|
||||
"video_images_load_time_ratio",
|
||||
"avg_mse",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
]
|
||||
file_paths = []
|
||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
||||
benchmark_table = []
|
||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for duet in [
|
||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
||||
]:
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for key, value in duet.items():
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
now = dt.datetime.now()
|
||||
csv_path = (
|
||||
output_dir
|
||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
||||
)
|
||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
||||
file_paths.append(csv_path)
|
||||
del benchmark_df
|
||||
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/video_benchmark"),
|
||||
help="Directory where the video benchmark outputs are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"lerobot/aloha_mobile_shrimp_image",
|
||||
"lerobot/paris_street",
|
||||
"lerobot/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["h264", "hevc", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["yuv444p", "yuv420p"],
|
||||
help="Pixel formats (chroma subsampling) to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
help="Group of pictures sizes to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
||||
help="Constant rate factors to be tested.",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fastdecode",
|
||||
# type=int,
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265/hevc, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--timestamps-modes",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
],
|
||||
help="Timestamps scenarios to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["torchcodec", "pyav"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of samples for each encoding x decoding config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of processes for parallelized sample processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
@@ -35,7 +35,7 @@ USER root
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-6 cuda-cudart-dev-12-6 \
|
||||
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
ARG CUDA_VERSION=12.6.3
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
title: Bring Your Own Hardware
|
||||
- local: hilserl
|
||||
@@ -24,6 +24,12 @@
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: hardware_guide
|
||||
title: Compute Hardware Guide
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Compute & Hardware"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
title: Using LeRobotDataset
|
||||
@@ -31,10 +37,8 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
@@ -144,10 +148,6 @@
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
title: "Supported Hardware"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
title: Notebooks
|
||||
|
||||
@@ -1,60 +1,37 @@
|
||||
# Bring Your Own Policies
|
||||
# Adding a Policy
|
||||
|
||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
||||
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
|
||||
|
||||
## Step 1: Create a Policy Package
|
||||
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
|
||||
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
|
||||
|
||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
||||
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
|
||||
|
||||
### Package Structure
|
||||
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_custom_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_custom_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_custom_policy.py
|
||||
├── modeling_my_custom_policy.py
|
||||
└── processor_my_custom_policy.py
|
||||
```
|
||||
---
|
||||
|
||||
### Package Configuration
|
||||
## Anatomy of a policy
|
||||
|
||||
Set up your `pyproject.toml`:
|
||||
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_custom_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
### Configuration class
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
||||
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
# configuration_my_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig
|
||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@PreTrainedConfig.register_subclass("my_policy")
|
||||
@dataclass
|
||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyCustomPolicy.
|
||||
class MyPolicyConfig(PreTrainedConfig):
|
||||
"""Configuration class for MyPolicy.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of observation steps to use as input
|
||||
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
raise ValueError("n_action_steps cannot exceed horizon")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
"""Validate input/output feature compatibility.
|
||||
|
||||
Call this explicitly from your policy's __init__ — the base class does not.
|
||||
"""
|
||||
if not self.image_features:
|
||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
||||
raise ValueError("MyPolicy requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
||||
raise ValueError("MyPolicy requires 'action' in output_features.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
|
||||
return None
|
||||
|
||||
@property
|
||||
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
||||
"""
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns."""
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
return None
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
|
||||
|
||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
||||
### Policy class
|
||||
|
||||
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
# modeling_my_policy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_custom_policy"
|
||||
class MyPolicy(PreTrainedPolicy):
|
||||
config_class = MyPolicyConfig # must match the string in @register_subclass
|
||||
name = "my_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features() # not called automatically by the base class
|
||||
self.config = config
|
||||
self.model = ... # your nn.Module here
|
||||
|
||||
def reset(self):
|
||||
"""Reset episode state."""
|
||||
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
|
||||
...
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
|
||||
...
|
||||
|
||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Return a single action for the current timestep (called at inference)."""
|
||||
"""Return a single action for the current timestep (called every step at inference)."""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
|
||||
"""Compute the training loss.
|
||||
|
||||
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
|
||||
logging-friendly Python natives (no tensors with gradients).
|
||||
|
||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||
timesteps padded because the episode ended before `horizon` steps, you
|
||||
timesteps padded because the episode ended before `horizon` steps; you
|
||||
can exclude those from your loss.
|
||||
"""
|
||||
actions = batch[ACTION]
|
||||
action_is_pad = batch.get("action_is_pad")
|
||||
...
|
||||
return {"loss": ...}
|
||||
return loss, {"some_loss_component": some_loss_component.item()}
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
The methods called by the train/eval loops:
|
||||
|
||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
| Method | Used by | What it does |
|
||||
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
|
||||
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
|
||||
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
|
||||
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
|
||||
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
|
||||
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
|
||||
|
||||
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
|
||||
|
||||
### Processor functions
|
||||
|
||||
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
# processor_my_policy.py
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
def make_my_policy_pre_post_processors(
|
||||
config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
**Important — function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
|
||||
## Step 5: Package Initialization
|
||||
---
|
||||
|
||||
Expose your classes in the package's `__init__.py`:
|
||||
## Path A: Out-of-tree plugin
|
||||
|
||||
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
|
||||
|
||||
### Package structure
|
||||
|
||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||
|
||||
```bash
|
||||
lerobot_policy_my_policy/
|
||||
├── pyproject.toml
|
||||
└── src/
|
||||
└── lerobot_policy_my_policy/
|
||||
├── __init__.py
|
||||
├── configuration_my_policy.py
|
||||
├── modeling_my_policy.py
|
||||
└── processor_my_policy.py
|
||||
```
|
||||
|
||||
### `pyproject.toml`
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "lerobot_policy_my_policy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
# your policy-specific dependencies
|
||||
]
|
||||
requires-python = ">= 3.12"
|
||||
|
||||
[build-system]
|
||||
build-backend = # your-build-backend
|
||||
requires = # your-build-system
|
||||
```
|
||||
|
||||
### Package `__init__.py`
|
||||
|
||||
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
|
||||
|
||||
```python
|
||||
# __init__.py
|
||||
@@ -204,44 +239,148 @@ except ImportError:
|
||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||
)
|
||||
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
from .modeling_my_custom_policy import MyCustomPolicy
|
||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
||||
from .configuration_my_policy import MyPolicyConfig
|
||||
from .modeling_my_policy import MyPolicy
|
||||
from .processor_my_policy import make_my_policy_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"MyCustomPolicyConfig",
|
||||
"MyCustomPolicy",
|
||||
"make_my_custom_policy_pre_post_processors",
|
||||
"MyPolicyConfig",
|
||||
"MyPolicy",
|
||||
"make_my_policy_pre_post_processors",
|
||||
]
|
||||
```
|
||||
|
||||
## Step 6: Installation and Usage
|
||||
|
||||
### Install Your Policy Package
|
||||
### Install and use
|
||||
|
||||
```bash
|
||||
cd lerobot_policy_my_custom_policy
|
||||
cd lerobot_policy_my_policy
|
||||
pip install -e .
|
||||
|
||||
# Or install from PyPI if published
|
||||
pip install lerobot_policy_my_custom_policy
|
||||
pip install lerobot_policy_my_policy
|
||||
```
|
||||
|
||||
### Use Your Policy
|
||||
|
||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type my_custom_policy \
|
||||
--policy.type my_policy \
|
||||
--env.type pusht \
|
||||
--steps 200000
|
||||
```
|
||||
|
||||
## Examples and Community Contributions
|
||||
---
|
||||
|
||||
## Path B: Contributing in-tree
|
||||
|
||||
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
|
||||
|
||||
### In-tree layout
|
||||
|
||||
```
|
||||
src/lerobot/policies/my_policy/
|
||||
├── __init__.py # re-exports config + modeling + processor factory
|
||||
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
|
||||
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
|
||||
├── processor_my_policy.py # make_my_policy_pre_post_processors
|
||||
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
|
||||
```
|
||||
|
||||
Two notes:
|
||||
|
||||
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
|
||||
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
|
||||
|
||||
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
|
||||
|
||||
### Wiring
|
||||
|
||||
Three places need to know about your policy. All by name.
|
||||
|
||||
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
|
||||
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
|
||||
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
|
||||
|
||||
Mirror an existing policy that's structurally similar to yours; the diff is small.
|
||||
|
||||
### Heavy / optional dependencies
|
||||
|
||||
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
|
||||
|
||||
```python
|
||||
from typing import TYPE_CHECKING
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
else:
|
||||
DDIMScheduler = None # keeps the symbol bindable at import time
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(self, config):
|
||||
require_package("diffusers", extra="diffusion")
|
||||
super().__init__(config)
|
||||
...
|
||||
```
|
||||
|
||||
This way:
|
||||
|
||||
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
|
||||
- Type checkers see the real symbol.
|
||||
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
|
||||
|
||||
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
|
||||
|
||||
### Benchmarks and a published checkpoint
|
||||
|
||||
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
|
||||
|
||||
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
|
||||
|
||||
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
|
||||
|
||||
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
|
||||
|
||||
```markdown
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with `lerobot/<policy>_libero`:
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| -------------- | -----------: | ---------: |
|
||||
| libero_spatial | 87.5% | 50 |
|
||||
| libero_object | 93.0% | 50 |
|
||||
| libero_goal | 81.5% | 50 |
|
||||
| libero_10 | 62.0% | 50 |
|
||||
| **average** | **81.0%** | 200 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
|
||||
```
|
||||
|
||||
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
|
||||
|
||||
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
|
||||
|
||||
### PR checklist
|
||||
|
||||
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
|
||||
|
||||
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
|
||||
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
|
||||
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
|
||||
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
|
||||
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
|
||||
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
|
||||
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
|
||||
|
||||
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
|
||||
|
||||
---
|
||||
|
||||
## Examples and community contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) — Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
|
||||
|
||||
@@ -0,0 +1,277 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -0,0 +1,98 @@
|
||||
# Compute HW Guide for LeRobot Training
|
||||
|
||||
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
|
||||
|
||||
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
|
||||
|
||||
## Memory by policy group
|
||||
|
||||
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
|
||||
|
||||
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
|
||||
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
|
||||
| Light BC | `act`, `vqbet`, `tdmpc` | ~2–6GB | Laptop GPU (RTX 3060), L4, A10G |
|
||||
| Diffusion | `diffusion`, `multi_task_dit` | ~8–14GB | RTX 4070+ / L4 / A10G |
|
||||
| Small VLA | `smolvla` | ~10–16GB | RTX 4080+ / L4 / A10G |
|
||||
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~24–40GB | A100 40 GB+ (24 GB tight at BS 1) |
|
||||
| Multimodal | `groot`, `eo1` | ~24–40GB | A100 40 GB+ |
|
||||
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
|
||||
|
||||
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
|
||||
|
||||
## Training time
|
||||
|
||||
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
|
||||
|
||||
```text
|
||||
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
|
||||
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
|
||||
total_steps = epochs × steps_per_epoch
|
||||
wall_clock ≈ total_steps × per_step_time
|
||||
```
|
||||
|
||||
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
|
||||
|
||||
### Common scenarios
|
||||
|
||||
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
|
||||
|
||||
| Setup | Policy | Batch | Wall-clock |
|
||||
| ------------------------------------ | -------------- | ----- | ---------: |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~30–60min |
|
||||
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~2–4h |
|
||||
| Single L4 / A10G (24 GB) | `act` | 8 | ~1–2h |
|
||||
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~3–6h |
|
||||
| Single A100 40 GB | `smolvla` | 16 | ~1–2h |
|
||||
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~4–8h |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~30–60min |
|
||||
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~1–2h |
|
||||
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~6–14h |
|
||||
|
||||
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
|
||||
|
||||
### Multi-GPU matters a lot
|
||||
|
||||
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
|
||||
|
||||
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
|
||||
|
||||
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
|
||||
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
|
||||
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
|
||||
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
|
||||
|
||||
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
|
||||
|
||||
### Schedule and checkpoints
|
||||
|
||||
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
|
||||
|
||||
## Where to run
|
||||
|
||||
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` columns are relative; check current pricing on the provider you actually use.
|
||||
|
||||
| Class | VRAM | Tier | Comfortable for |
|
||||
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
|
||||
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
|
||||
| L4 / A10G (cloud) | 24 GB | `$–$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
|
||||
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
|
||||
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
|
||||
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
|
||||
|
||||
### Hugging Face Jobs
|
||||
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
|
||||
|
||||
```bash
|
||||
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
|
||||
bash -c "nvidia-smi && lerobot-train \
|
||||
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
|
||||
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||
+40
-37
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
|
||||
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
||||
```
|
||||
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
|
||||
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
|
||||
|
||||
**Example Configuration**
|
||||
|
||||
```json
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.20, 0.10],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.2, 0.1],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
|
||||
|
||||
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
|
||||
|
||||
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
|
||||
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
# bounds for the end-effector in x,y,z direction
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
# maximum step size for the end-effector in x,y,z direction
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
class HILSerlProcessorConfig:
|
||||
...
|
||||
# maximum gripper position that the gripper will be open at
|
||||
max_gripper_pos: float | None = 100.0
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
|
||||
|
||||
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
|
||||
|
||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
|
||||
**Collecting a Dataset for the reward classifier**
|
||||
|
||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
|
||||
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
||||
|
||||
@@ -658,7 +661,7 @@ Example configuration section for data collection:
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
@@ -671,7 +674,7 @@ Example configuration section for data collection:
|
||||
|
||||
**Reward Classifier Configuration**
|
||||
|
||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
||||
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
|
||||
|
||||
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
|
||||
- **model_type**: `"cnn"` or `"transformer"`
|
||||
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"root": null
|
||||
},
|
||||
"policy": {
|
||||
"reward_model": {
|
||||
"type": "reward_classifier",
|
||||
"model_name": "helper2424/resnet10",
|
||||
"model_type": "cnn",
|
||||
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"dropout_rate": 0.1,
|
||||
"learning_rate": 1e-4,
|
||||
"device": "cuda",
|
||||
"use_amp": true,
|
||||
"input_features": {
|
||||
"observation.images.front": {
|
||||
"type": "VISUAL",
|
||||
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
**Configuration Setup**
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
|
||||
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
|
||||
3. Set `dataset` to your cropped dataset
|
||||
4. Configure environment settings with crop parameters
|
||||
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
|
||||
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
||||
|
||||
Some configuration values have a disproportionate impact on training stability and speed:
|
||||
|
||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||
|
||||
|
||||
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
|
||||
|
||||
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
||||
|
||||
### PyTorch CUDA variant (Linux only)
|
||||
|
||||
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
|
||||
<hfoptions id="cuda_variant">
|
||||
<hfoption id="uv-source">
|
||||
|
||||
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
|
||||
|
||||
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
|
||||
|
||||
To override for a different CUDA variant:
|
||||
|
||||
```bash
|
||||
uv pip install --force-reinstall torch torchvision \
|
||||
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="pip-conda">
|
||||
|
||||
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
|
||||
|
||||
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
|
||||
|
||||
To pick a specific CUDA variant:
|
||||
|
||||
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
|
||||
|
||||
```bash
|
||||
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
|
||||
pip install -e ".[all]" # source
|
||||
# — or —
|
||||
pip install lerobot # from PyPI
|
||||
```
|
||||
|
||||
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
|
||||
|
||||
```bash
|
||||
uv pip install --torch-backend cu128 lerobot
|
||||
```
|
||||
|
||||
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float64 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
@@ -28,13 +28,15 @@ lerobot-train \
|
||||
--steps=100000 \
|
||||
--batch_size=32 \
|
||||
--peft.method_type=LORA \
|
||||
--peft.r=64
|
||||
--peft.r=64 \
|
||||
--peft.lora_alpha=64
|
||||
```
|
||||
|
||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
|
||||
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank
|
||||
the closer you get to full fine-tuning
|
||||
|
||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
@@ -0,0 +1,136 @@
|
||||
# OMX Follower — Cube Pick And Place Example
|
||||
|
||||
This is an example of what is possible to do with LeRobot on a physical setup.
|
||||
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
|
||||
|
||||
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
|
||||
|
||||
## Hardware
|
||||
|
||||
| Component | Value |
|
||||
| --------- | ------------------------------------ |
|
||||
| Robot | OMX Follower |
|
||||
| Cameras | 2× OpenCV cameras (wrist + top-down) |
|
||||
|
||||
## Scripts
|
||||
|
||||
| Script | Purpose |
|
||||
| ---------------------- | --------------------------------------------------------------- |
|
||||
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
|
||||
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
|
||||
|
||||
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
|
||||
|
||||
```bash
|
||||
export ROBOT_PORT=/dev/ttyACM0
|
||||
export TELEOP_PORT=/dev/ttyACM1
|
||||
export HF_USERNAME=<your_hf_username>
|
||||
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
|
||||
```
|
||||
|
||||
## Step 1 — Collect Data
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--teleop.type=omx_leader \
|
||||
--teleop.port=$TELEOP_PORT \
|
||||
--teleop.id=omx_leader \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
### Bonus Auto-Collect script
|
||||
|
||||
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
|
||||
|
||||
```bash
|
||||
python -m examples.omx.record_grab \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--dataset.root=data/omx_pickandplace \
|
||||
--dataset.num_episodes=50 \
|
||||
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Each episode:
|
||||
|
||||
1. The arm grabs the cube from the center of the workspace and places it at a random position.
|
||||
2. The arm returns to HOME.
|
||||
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
|
||||
|
||||
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
|
||||
|
||||
## Step 2 — Train
|
||||
|
||||
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/omx_pickandplace_act \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
|
||||
--steps=20000 \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
|
||||
|
||||
## Step 3 — Rollout
|
||||
|
||||
Use the `lerobot-rollout` CLI with base strategy:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
```
|
||||
|
||||
For continuous recording with automatic upload (sentry mode):
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=10 \
|
||||
--robot.type=omx_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=omx_follower \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
|
||||
```
|
||||
|
||||
## Environment Reset Utility
|
||||
|
||||
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
|
||||
|
||||
`reset_environment.py` can be run standalone to prepare the workspace:
|
||||
|
||||
```bash
|
||||
# Grab cube + place it at a random position on the left side
|
||||
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
|
||||
```
|
||||
|
||||
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
|
||||
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-record grab episodes for the OMX robot arm.
|
||||
|
||||
Each episode cycle:
|
||||
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
|
||||
2. HOME — return arm to home with gripper open
|
||||
3. record_grab — execute a targeted grab to the stored position while recording
|
||||
observations + actions to a LeRobotDataset
|
||||
|
||||
Usage (run from repo root):
|
||||
python -m examples.omx.record_grab \\
|
||||
--robot.type=omx_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--robot.id=omx_follower \\
|
||||
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
|
||||
--dataset.repo_id=<hf_username>/<dataset_name> \\
|
||||
--dataset.root=data/omx_grab \\
|
||||
--dataset.num_episodes=50 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import CameraConfig # noqa: F401
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
)
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||
from lerobot.robots.omx_follower import OmxFollower
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from .reset_environment import (
|
||||
APPROACH_SPEED,
|
||||
GRIPPER_CLOSE_POS,
|
||||
HOME_POSE,
|
||||
PUSH_END_ELBOW_FLEX,
|
||||
PUSH_END_SHOULDER_LIFT,
|
||||
PUSH_START_ELBOW_FLEX,
|
||||
PUSH_START_SHOULDER_LIFT,
|
||||
array_to_pose,
|
||||
grab_cube,
|
||||
horizontal_wrist_flex,
|
||||
move_to_pose,
|
||||
place_cube,
|
||||
pose_to_array,
|
||||
)
|
||||
|
||||
# ── Grab-episode motion parameters ────────────────────────────────────────────
|
||||
|
||||
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
|
||||
GRAB_RAISE_SL_OFFSET = 20.0
|
||||
GRAB_LOWER_SPEED = 20.0
|
||||
RECORD_SPEED = 30.0
|
||||
|
||||
# Pose the arm travels to after closing the gripper (cube held).
|
||||
GRAB_CARRY_POSE = {
|
||||
"shoulder_pan.pos": -23.0,
|
||||
"shoulder_lift.pos": 5.0,
|
||||
"elbow_flex.pos": 18.0,
|
||||
"wrist_flex.pos": -14.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
|
||||
# Cube-approach and carry poses are never jittered to preserve precision.
|
||||
_JITTER_LIMITS: dict[str, float] = {
|
||||
"shoulder_pan.pos": 5.0,
|
||||
"shoulder_lift.pos": 4.0,
|
||||
"elbow_flex.pos": 4.0,
|
||||
"wrist_flex.pos": 3.0,
|
||||
"wrist_roll.pos": 2.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
|
||||
"""Return a copy of pose with independent per-joint random perturbations."""
|
||||
return {
|
||||
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
|
||||
}
|
||||
|
||||
|
||||
def _random_stuck_pose(rng: np.random.Generator) -> dict:
|
||||
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
|
||||
|
||||
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
|
||||
table-safe envelope across the full sl range:
|
||||
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
|
||||
sl= 0 → ef ∈ [-25, 25] (mid reach)
|
||||
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
|
||||
wrist_flex is randomly offset from the horizontal value.
|
||||
"""
|
||||
pan = float(rng.uniform(-5.0, 35.0))
|
||||
sl = float(rng.uniform(-50.0, 30.0))
|
||||
|
||||
if sl <= 0.0:
|
||||
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
|
||||
ef_lo = alpha * -25.0 # 0 → -25
|
||||
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
|
||||
else:
|
||||
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
|
||||
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
|
||||
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
|
||||
|
||||
ef = float(rng.uniform(ef_lo, ef_hi))
|
||||
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf,
|
||||
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmxRecordGrabConfig:
|
||||
robot: RobotConfig
|
||||
dataset: DatasetRecordConfig
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Fraction of episodes that start from a random stuck pose (gripper closed) to
|
||||
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
|
||||
recovery_prob: float = 0.5
|
||||
|
||||
|
||||
def record_episode_spline(
|
||||
robot: OmxFollower,
|
||||
waypoints: list[dict],
|
||||
speeds: list[float],
|
||||
dataset: LeRobotDataset,
|
||||
task: str,
|
||||
) -> None:
|
||||
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
|
||||
|
||||
Segment durations are parameterized from the maximum absolute joint delta
|
||||
between consecutive waypoints divided by the requested segment speed,
|
||||
producing non-uniform timing in joint space. Interior tangents are derived
|
||||
from the adjacent per-segment velocities, with clamped (zero-velocity)
|
||||
endpoints so the arm starts and stops smoothly. Each segment is cubic
|
||||
Hermite, giving C1 continuity at every waypoint.
|
||||
"""
|
||||
pts = [pose_to_array(w) for w in waypoints]
|
||||
n = len(pts)
|
||||
|
||||
# Steps and duration per segment
|
||||
n_steps_list = []
|
||||
timestamps = []
|
||||
for i in range(n - 1):
|
||||
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
|
||||
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
|
||||
n_steps_list.append(ns)
|
||||
timestamps.append(ns / dataset.fps)
|
||||
|
||||
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
|
||||
vels = [np.zeros_like(pts[0])]
|
||||
for i in range(1, n - 1):
|
||||
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
|
||||
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
|
||||
vels.append(0.5 * (v_prev + v_next))
|
||||
vels.append(np.zeros_like(pts[0]))
|
||||
|
||||
dt = 1.0 / dataset.fps
|
||||
for seg in range(n - 1):
|
||||
ns = n_steps_list[seg]
|
||||
if ns == 0:
|
||||
continue
|
||||
p0, p1 = pts[seg], pts[seg + 1]
|
||||
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
|
||||
m0 = vels[seg] * timestamps[seg]
|
||||
m1 = vels[seg + 1] * timestamps[seg]
|
||||
|
||||
for step in range(1, ns + 1):
|
||||
t = step / ns
|
||||
h00 = 2 * t**3 - 3 * t**2 + 1
|
||||
h10 = t**3 - 2 * t**2 + t
|
||||
h01 = -2 * t**3 + 3 * t**2
|
||||
h11 = t**3 - t**2
|
||||
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||
|
||||
action = array_to_pose(commanded)
|
||||
robot.send_action(action)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
def record_grab_episode(
|
||||
robot: OmxFollower,
|
||||
dataset: LeRobotDataset,
|
||||
pan: float,
|
||||
t: float,
|
||||
task: str,
|
||||
recovery_start: bool = False,
|
||||
) -> None:
|
||||
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
|
||||
|
||||
Normal sequence (initial HOME move is NOT recorded):
|
||||
HOME → raised approach above cube → lower → close gripper
|
||||
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
|
||||
|
||||
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
|
||||
(gripper closed) without recording, then recording begins from there:
|
||||
stuck_pose → raised approach above cube → [normal grab sequence from there]
|
||||
|
||||
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
|
||||
"""
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
sl_raised = sl - GRAB_RAISE_SL_OFFSET
|
||||
wf_horizontal = horizontal_wrist_flex(sl, ef)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
|
||||
if recovery_start:
|
||||
stuck_pose = _random_stuck_pose(rng)
|
||||
logger.info(f"Recovery start: {stuck_pose}")
|
||||
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
|
||||
first_waypoints = [stuck_pose]
|
||||
first_speeds = []
|
||||
else:
|
||||
jittery_start = _jitter_pose(HOME_POSE, rng)
|
||||
move_to_pose(robot, jittery_start, APPROACH_SPEED)
|
||||
first_waypoints = [jittery_start]
|
||||
first_speeds = []
|
||||
|
||||
waypoints = first_waypoints + [
|
||||
{ # raised approach: arm above cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # lower onto cube — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{ # close gripper — no jitter: precision needed
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": wf_horizontal,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
_jitter_pose(
|
||||
{ # raise with cube
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl_raised,
|
||||
"elbow_flex.pos": ef,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
_jitter_pose(
|
||||
{ # retract: fold arm toward HOME before sweeping to carry zone
|
||||
"shoulder_pan.pos": pan * 0.25,
|
||||
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
|
||||
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||
},
|
||||
rng,
|
||||
),
|
||||
GRAB_CARRY_POSE, # no jitter: target drop zone
|
||||
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
|
||||
HOME_POSE,
|
||||
]
|
||||
speeds = first_speeds + [
|
||||
RECORD_SPEED, # (HOME →) raised approach
|
||||
GRAB_LOWER_SPEED, # raised approach → lower
|
||||
GRAB_LOWER_SPEED, # lower → close gripper
|
||||
RECORD_SPEED, # close gripper → raise
|
||||
RECORD_SPEED, # raise → retract
|
||||
RECORD_SPEED, # retract → carry pose
|
||||
RECORD_SPEED, # carry pose → drop
|
||||
RECORD_SPEED, # drop → HOME
|
||||
]
|
||||
|
||||
record_episode_spline(robot, waypoints, speeds, dataset, task)
|
||||
|
||||
# Dwell at HOME for ~0.5 s before next episode
|
||||
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
|
||||
dt = 1.0 / dataset.fps
|
||||
for _ in range(int(dataset.fps * 0.5)):
|
||||
robot.send_action(HOME_POSE)
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||
dataset.add_frame({**obs_frame, **home_action, "task": task})
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
logger.info(pformat(cfg))
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
use_videos = cfg.dataset.video
|
||||
|
||||
teleop_action_processor, _, robot_obs_processor = make_default_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_obs_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=use_videos,
|
||||
),
|
||||
)
|
||||
|
||||
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||
dataset = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
else:
|
||||
cfg.dataset.stamp_repo_id()
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=use_videos,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
rng = np.random.default_rng()
|
||||
with VideoEncodingManager(dataset):
|
||||
for episode_idx in range(cfg.dataset.num_episodes):
|
||||
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
|
||||
|
||||
logger.info("Step 1: grabbing and placing cube...")
|
||||
grab_cube(robot)
|
||||
pan, t = place_cube(robot)
|
||||
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
|
||||
|
||||
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
|
||||
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
|
||||
record_grab_episode(
|
||||
robot,
|
||||
dataset,
|
||||
pan,
|
||||
t,
|
||||
cfg.dataset.single_task,
|
||||
recovery_start=recovery_start,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
logger.info(f"Episode {episode_idx + 1} saved.")
|
||||
|
||||
finally:
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
record_grab()
|
||||
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Auto-reset and cube-grab utility for the OMX robot arm.
|
||||
|
||||
Provides:
|
||||
- grab_cube(robot): sweep workspace, center cube, close gripper
|
||||
- place_cube(robot): carry cube to a random position, release
|
||||
|
||||
Standalone usage (run from repo root):
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
|
||||
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
|
||||
|
||||
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
|
||||
|
||||
To read current joint values for calibration, add after robot.connect():
|
||||
obs = robot.get_observation()
|
||||
print({k: round(obs[k], 1) for k in JOINT_NAMES})
|
||||
robot.disconnect(); raise SystemExit
|
||||
|
||||
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
|
||||
Linear interpolation preserves this constraint between any two poses that satisfy it.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Poses ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
HOME_POSE = {
|
||||
"shoulder_pan.pos": 0.0,
|
||||
"shoulder_lift.pos": -50.0,
|
||||
"elbow_flex.pos": 50.0,
|
||||
"wrist_flex.pos": 0.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
SWEEP_WAYPOINTS = [
|
||||
{
|
||||
"shoulder_pan.pos": -60.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -20.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": -30.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -60.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
{
|
||||
"shoulder_pan.pos": 20.0,
|
||||
"shoulder_lift.pos": 50.0,
|
||||
"elbow_flex.pos": -55.0,
|
||||
"wrist_flex.pos": -5.0,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
},
|
||||
]
|
||||
|
||||
# ── Motion parameters ─────────────────────────────────────────────────────────
|
||||
|
||||
CONTROL_HZ = 30
|
||||
APPROACH_SPEED = 50.0
|
||||
SWEEP_SPEED = 40.0
|
||||
|
||||
# ── Grab-sequence parameters ──────────────────────────────────────────────────
|
||||
|
||||
GRAB_PAN = 0.0
|
||||
SWEEP_LEFT_PAN = -60.0
|
||||
SWEEP_RIGHT_PAN = 60.0
|
||||
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
|
||||
SWEEP_END_PAN_RANGE = (15.0, 20.0)
|
||||
|
||||
SWEEP_LOW_SHOULDER_LIFT = 50.0
|
||||
SWEEP_LOW_ELBOW_FLEX_START = -60.0
|
||||
SWEEP_LOW_ELBOW_FLEX_END = -55.0
|
||||
|
||||
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
|
||||
|
||||
PUSH_START_SHOULDER_LIFT = 0.0
|
||||
PUSH_START_ELBOW_FLEX = 45.0
|
||||
PUSH_END_SHOULDER_LIFT = 50.0
|
||||
PUSH_END_ELBOW_FLEX = -50.0
|
||||
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
|
||||
# Does not affect the grab-target interpolation in record_grab.py.
|
||||
PUSH_RAISE_OFFSET = 5.0
|
||||
|
||||
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
|
||||
GRIPPER_CLOSE_POS = 50.0
|
||||
|
||||
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
|
||||
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
|
||||
|
||||
JOINT_NAMES = [
|
||||
"shoulder_pan.pos",
|
||||
"shoulder_lift.pos",
|
||||
"elbow_flex.pos",
|
||||
"wrist_flex.pos",
|
||||
"wrist_roll.pos",
|
||||
"gripper.pos",
|
||||
]
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def pose_to_array(pose: dict) -> np.ndarray:
|
||||
return np.array([pose[k] for k in JOINT_NAMES])
|
||||
|
||||
|
||||
def array_to_pose(arr: np.ndarray) -> dict:
|
||||
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
|
||||
|
||||
|
||||
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
|
||||
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
|
||||
|
||||
|
||||
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
|
||||
sl = SWEEP_LOW_SHOULDER_LIFT
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": sl,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": 60.0,
|
||||
}
|
||||
|
||||
|
||||
def _high_sweep_pose(pan: float) -> dict:
|
||||
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
|
||||
|
||||
|
||||
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
|
||||
return {
|
||||
"shoulder_pan.pos": pan,
|
||||
"shoulder_lift.pos": shoulder_lift,
|
||||
"elbow_flex.pos": elbow_flex,
|
||||
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
|
||||
"wrist_roll.pos": 0.0,
|
||||
"gripper.pos": gripper,
|
||||
}
|
||||
|
||||
|
||||
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
|
||||
"""Interpolate from current position to target at the given speed (units/s)."""
|
||||
obs = robot.get_observation()
|
||||
current = np.array([obs[k] for k in JOINT_NAMES])
|
||||
goal = pose_to_array(target)
|
||||
|
||||
max_distance = float(np.max(np.abs(goal - current)))
|
||||
if max_distance < 0.5:
|
||||
return
|
||||
|
||||
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
|
||||
dt = 1.0 / CONTROL_HZ
|
||||
for step in range(1, n_steps + 1):
|
||||
t = step / n_steps
|
||||
robot.send_action(array_to_pose(current + t * (goal - current)))
|
||||
precise_sleep(dt)
|
||||
|
||||
|
||||
# ── Sequences ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def grab_cube(robot: Robot) -> None:
|
||||
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
for pan, end_pan in [
|
||||
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
|
||||
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
|
||||
]:
|
||||
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
|
||||
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
|
||||
logger.info("Extending to push cube into gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
|
||||
SWEEP_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Closing gripper...")
|
||||
move_to_pose(
|
||||
robot,
|
||||
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
|
||||
APPROACH_SPEED,
|
||||
)
|
||||
|
||||
logger.info("Grab complete.")
|
||||
|
||||
|
||||
def place_cube(robot: Robot) -> tuple[float, float]:
|
||||
"""Carry the cube (gripper closed) to a random position on the left side, then release.
|
||||
|
||||
Returns:
|
||||
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
|
||||
"""
|
||||
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
|
||||
t = float(np.random.uniform(*PLACE_REACH_RANGE))
|
||||
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
|
||||
|
||||
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
|
||||
move_to_pose(
|
||||
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
|
||||
)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
|
||||
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
|
||||
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||
logger.info("Place complete.")
|
||||
return pan, t
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
|
||||
parser.add_argument("--port", default="/dev/ttyACM1")
|
||||
parser.add_argument("--robot_id", default="omx_follower")
|
||||
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
|
||||
robot.connect(calibrate=True)
|
||||
|
||||
try:
|
||||
if args.mode == "grab":
|
||||
grab_cube(robot)
|
||||
elif args.mode == "grab_and_place":
|
||||
grab_cube(robot)
|
||||
place_cube(robot)
|
||||
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -40,8 +40,9 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -83,24 +84,26 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
stats = algorithm.update(batch_iter())
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -113,7 +116,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -144,15 +147,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
@@ -261,14 +264,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
policy_cfg = GaussianActorConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
+30
-5
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
|
||||
dependencies = [
|
||||
# Core ML
|
||||
"torch>=2.7,<2.11.0",
|
||||
"torchvision>=0.22.0,<0.26.0",
|
||||
"torch>=2.7,<2.12.0",
|
||||
"torchvision>=0.22.0,<0.27.0",
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
@@ -95,11 +95,22 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
||||
|
||||
# NOTE: torchcodec wheel availability matrix (PyPI):
|
||||
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
|
||||
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
|
||||
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
|
||||
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
|
||||
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
|
||||
|
||||
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
|
||||
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
|
||||
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
|
||||
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
]
|
||||
training = [
|
||||
@@ -195,7 +206,7 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -293,6 +304,20 @@ lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
|
||||
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
|
||||
# uv pip install --force-reinstall torch torchvision \
|
||||
# --index-url https://download.pytorch.org/whl/cu130
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -24,7 +24,6 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -44,10 +43,7 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
]
|
||||
|
||||
@@ -117,3 +117,9 @@ class PeftConfig:
|
||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||
# fine-tuning.
|
||||
r: int = 16
|
||||
|
||||
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
|
||||
# In general, a higher alpha means stronger adaptation signal.
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import pkgutil
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
@@ -24,6 +26,7 @@ from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
# Storage for path args extracted from YAML/JSON config files, so that
|
||||
# get_path_arg() can find them even when they weren't passed via CLI.
|
||||
_config_path_args: dict[str, str] = {}
|
||||
|
||||
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
|
||||
_config_yaml_overrides: dict[str, list[str]] = {}
|
||||
|
||||
|
||||
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
|
||||
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
|
||||
args = []
|
||||
for key, value in d.items():
|
||||
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
|
||||
continue
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, bool):
|
||||
value = str(value).lower()
|
||||
if isinstance(value, dict):
|
||||
args.extend(_flatten_to_cli_args(value, full_key))
|
||||
elif value is not None and not isinstance(value, list):
|
||||
args.append(f"--{full_key}={value}")
|
||||
return args
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
if result is None:
|
||||
result = _config_path_args.get(field_name)
|
||||
return result
|
||||
|
||||
|
||||
def get_yaml_overrides(field_name: str) -> list[str]:
|
||||
return _config_yaml_overrides.get(field_name, [])
|
||||
|
||||
|
||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
@@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
|
||||
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
|
||||
|
||||
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
|
||||
draccus will fail because ``path`` is not a valid field on policy config classes.
|
||||
This function extracts those path values, stores them in ``_config_path_args`` for
|
||||
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
|
||||
"""
|
||||
config_file = Path(config_path)
|
||||
suffix = config_file.suffix.lower()
|
||||
|
||||
if suffix in (".yaml", ".yml"):
|
||||
with open(config_file) as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
elif suffix == ".json":
|
||||
with open(config_file) as f:
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
return config_path
|
||||
|
||||
if not isinstance(config_data, dict):
|
||||
return config_path
|
||||
|
||||
modified = False
|
||||
for field in path_fields:
|
||||
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
|
||||
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
else:
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
return config_path
|
||||
|
||||
# Write cleaned config to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
|
||||
if suffix in (".yaml", ".yml"):
|
||||
yaml.dump(config_data, tmp, default_flow_style=False)
|
||||
else:
|
||||
json.dump(config_data, tmp, indent=2)
|
||||
return tmp.name
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
@@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
cli_args = filter_path_args(path_fields, cli_args)
|
||||
# Also extract path fields from the YAML/JSON config file
|
||||
if config_path_cli:
|
||||
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
|
||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
@@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -256,7 +259,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
if config_file is not None:
|
||||
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||
if config_file is not None and config_file.endswith(".json"):
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
@@ -267,10 +272,3 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -37,14 +37,6 @@ from .dataset_tools import (
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
@@ -61,15 +53,10 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"add_features",
|
||||
@@ -79,7 +66,6 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
if features[key]["dtype"] == "string":
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -34,6 +35,7 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -174,6 +176,7 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -187,6 +190,29 @@ class LeRobotDatasetMetadata:
|
||||
if self.episodes is None:
|
||||
self._load_metadata()
|
||||
|
||||
def filter_episodes(
|
||||
self,
|
||||
predicate: Callable[[dict], bool],
|
||||
candidates: list[int] | None = None,
|
||||
) -> list[int]:
|
||||
"""Filter episodes whose metadata satisfies a given predicate.
|
||||
|
||||
Args:
|
||||
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||
candidates: Optional list of episode indices to restrict evaluation to.
|
||||
|
||||
Returns:
|
||||
List of sorted episode indices that satisfy the predicate.
|
||||
"""
|
||||
self.ensure_readable()
|
||||
if candidates is not None:
|
||||
candidate_set = set(candidates)
|
||||
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
|
||||
else:
|
||||
combined = predicate
|
||||
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
|
||||
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||
|
||||
def _pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
@@ -316,39 +342,6 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -664,6 +657,7 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -295,4 +295,9 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -22,12 +22,6 @@ from PIL import Image as PILImage
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -52,13 +46,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -254,8 +242,6 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return ""
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -186,6 +186,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -257,13 +265,11 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -298,9 +304,8 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float64(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float64"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
@@ -1,543 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
|
||||
is also a float64 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz)."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = (
|
||||
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
|
||||
)
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
value = row["timestamp"]
|
||||
return float(value.item() if hasattr(value, "item") else value)
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
episode_filter: Callable[[dict], bool] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||
Defaults to None.
|
||||
image_transforms (Callable | None, optional):
|
||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||
conversion. This works for both image-backed and video-backed observations and can later be
|
||||
@@ -199,7 +205,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.reader = None
|
||||
self.set_image_transforms(image_transforms)
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
@@ -218,6 +223,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
|
||||
if episodes is not None and any(
|
||||
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||
):
|
||||
logger.warning(
|
||||
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||
)
|
||||
|
||||
if episode_filter is not None:
|
||||
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||
if not resolved:
|
||||
raise ValueError(
|
||||
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||
)
|
||||
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||
episodes = resolved
|
||||
self.episodes = episodes
|
||||
|
||||
# Create reader (hf_dataset loaded below)
|
||||
self.reader = DatasetReader(
|
||||
meta=self.meta,
|
||||
|
||||
@@ -88,6 +88,7 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -129,9 +130,6 @@ class DatasetInfo:
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
@@ -153,15 +151,11 @@ class DatasetInfo:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -33,7 +33,6 @@ import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
@@ -132,7 +131,9 @@ def decode_video_frames(
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
|
||||
@@ -145,85 +146,87 @@ def decode_video_frames(
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
|
||||
)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
def decode_video_frames_pyav(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
"video_reader" requires installing torchvision from source, see:
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
|
||||
x86_64 and linux armv7l — see the torchcodec block in pyproject.toml for the full matrix).
|
||||
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
|
||||
accurate seek.
|
||||
|
||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
||||
For more info on video decoding, see `benchmark/video/README.md`
|
||||
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
|
||||
forward until we have covered the requested timestamp range. The number of key frames in a
|
||||
video can be adjusted at encoding time to trade off decoding speed against file size.
|
||||
|
||||
See torchvision doc for more info on these two backends:
|
||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
timestamps: List of timestamps (in seconds) to extract frames for.
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
"""
|
||||
video_path = str(video_path)
|
||||
|
||||
# set backend
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesn't support accurate seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
reader = torchvision.io.VideoReader(video_path, "video")
|
||||
video_path = str(video_path)
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
|
||||
# load all frames until last requested frame
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
|
||||
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
|
||||
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
|
||||
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
|
||||
with av.open(video_path) as container:
|
||||
stream = container.streams.video[0]
|
||||
container.seek(int(first_ts * av.time_base), backward=True)
|
||||
|
||||
if backend == "pyav":
|
||||
reader.container.close()
|
||||
for frame in container.decode(stream):
|
||||
if frame.pts is None:
|
||||
continue
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
reader = None
|
||||
if not loaded_frames:
|
||||
raise FrameTimestampError(
|
||||
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
|
||||
)
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
loaded_ts_t = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and timestamps of all loaded frames
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
@@ -234,14 +237,14 @@ def decode_video_frames_torchvision(
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts_t}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
f"\nbackend: pyav"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
closest_ts = loaded_ts_t[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"{closest_ts=}")
|
||||
@@ -282,7 +285,11 @@ class VideoDecoderCache:
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
return self._cache[video_path][0]
|
||||
|
||||
@@ -18,13 +18,13 @@ from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -32,21 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
||||
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
||||
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"SACConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -47,12 +47,12 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
@@ -88,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -127,10 +127,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
return SACPolicy
|
||||
return GaussianActorPolicy
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
@@ -167,7 +167,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
@@ -191,8 +191,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
@@ -365,10 +365,10 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from .sac.processor_sac import make_sac_pre_post_processors
|
||||
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
processors = make_gaussian_actor_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
+4
-4
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .modeling_sac import SACPolicy
|
||||
from .processor_sac import make_sac_pre_post_processors
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
||||
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||
+37
-59
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
@@ -75,18 +75,19 @@ class PolicyConfig:
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
class GaussianActorConfig(PreTrainedConfig):
|
||||
"""Gaussian actor configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
CLI: ``--policy.type=gaussian_actor``.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Encoder architecture
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
|
||||
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Capacity of the online replay buffer
|
||||
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
# Network architecture
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters (Gaussian head)
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
# Any validation specific to GaussianActor configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
|
||||
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
|
||||
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
|
||||
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
|
||||
# this preset.
|
||||
default_lr = 3e-4
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
"actor": {"lr": default_lr},
|
||||
"critic": {"lr": default_lr},
|
||||
"temperature": {"lr": default_lr},
|
||||
},
|
||||
)
|
||||
|
||||
+48
-447
@@ -15,16 +15,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_sac import SACConfig, is_image_feature
|
||||
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
class GaussianActorPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
config: GaussianActorConfig | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -54,9 +49,8 @@ class SACPolicy(
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
self._init_discrete_critic()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
@@ -65,11 +59,7 @@ class SACPolicy(
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
@@ -79,7 +69,9 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
raise NotImplementedError(
|
||||
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -92,360 +84,43 @@ class SACPolicy(
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
if self.discrete_critic is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
else:
|
||||
discrete_action = torch.ones(
|
||||
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
|
||||
"""Actor forward pass: sample actions and return log-probabilities.
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
batch: A flat observation dict, or a training dict containing
|
||||
``"state"`` (observations) and optionally ``"observation_feature"``
|
||||
(pre-computed encoder features).
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
"""Initialize policy actor network."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
@@ -455,21 +130,25 @@ class SACPolicy(
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
def _init_discrete_critic(self) -> None:
|
||||
"""Initialize discrete critic network."""
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
class GaussianActorObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
def __init__(self, config: GaussianActorConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._init_image_layers()
|
||||
@@ -677,84 +356,6 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class DiscreteCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_min: float = -5,
|
||||
@@ -811,7 +412,7 @@ class Policy(nn.Module):
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.encoder: GaussianActorObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_min = std_min
|
||||
@@ -885,7 +486,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
+5
-5
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
def make_gaussian_actor_pre_post_processors(
|
||||
config: GaussianActorConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
config: The configuration object for the tanh-Gaussian policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
|
||||
Returns:
|
||||
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -95,13 +95,6 @@ from .relative_action_processor import (
|
||||
from .rename_processor import RenameObservationsProcessorStep, rename_stats
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
# RenderMessagesStep is intentionally NOT re-exported here: it pulls in
|
||||
# `lerobot.datasets.language`, which requires the `[dataset]` extra
|
||||
# (`datasets`, `pyarrow`). Importing it from the processor package would
|
||||
# break every base-install consumer of `lerobot.processor`. Users that
|
||||
# need it import directly:
|
||||
# from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
|
||||
@@ -174,24 +174,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
complementary_data["messages"] = [messages]
|
||||
|
||||
if "message_streams" in complementary_data:
|
||||
streams = complementary_data["message_streams"]
|
||||
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
|
||||
complementary_data["message_streams"] = [streams]
|
||||
|
||||
if "target_message_indices" in complementary_data:
|
||||
indices = complementary_data["target_message_indices"]
|
||||
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
|
||||
complementary_data["target_message_indices"] = [indices]
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -153,30 +153,26 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
_COMPLEMENTARY_KEYS = (
|
||||
"task",
|
||||
"index",
|
||||
"task_index",
|
||||
"episode_index",
|
||||
"timestamp",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
"messages",
|
||||
"message_streams",
|
||||
"target_message_indices",
|
||||
)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract complementary data from a batch dictionary.
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
|
||||
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||
each only when present in ``batch``.
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||
return {**pad_keys, **extras}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
@@ -321,6 +320,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
3. Copying `discrete_penalty` from `info` to `complementary_data`.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -330,6 +330,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if DISCRETE_PENALTY_KEY in info:
|
||||
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
@@ -348,18 +351,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
Applies a small per-transition cost on the discrete gripper action.
|
||||
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
Fires only when the commanded action would actually transition the gripper
|
||||
from one extreme to the other (close-while-open or open-while-closed).
|
||||
This discourages gripper oscillation while leaving "stay" and saturating-further
|
||||
commands unpenalized.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
open_threshold: Normalized state below which the gripper is considered "open".
|
||||
closed_threshold: Normalized state above which the gripper is considered "closed".
|
||||
"""
|
||||
|
||||
penalty: float = -0.01
|
||||
penalty: float = -0.02
|
||||
max_gripper_pos: float = 30.0
|
||||
open_threshold: float = 0.1
|
||||
closed_threshold: float = 0.9
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
@@ -379,11 +388,15 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
if raw_joint_positions is None:
|
||||
return new_transition
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
current_gripper_pos = raw_joint_positions.get(f"{GRIPPER_KEY}.pos", None)
|
||||
if current_gripper_pos is None:
|
||||
return new_transition
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
# During reset, the transition may not carry any action yet.
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
# Gripper action is expected as the last action dimension.
|
||||
gripper_action = action[-1].item()
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
@@ -391,9 +404,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
# - currently open AND target is closed -> close transition
|
||||
# - currently closed AND target is open -> open transition
|
||||
is_open = gripper_state_normalized < self.open_threshold
|
||||
is_closed = gripper_state_normalized > self.closed_threshold
|
||||
cmd_close = gripper_action_normalized > self.closed_threshold
|
||||
cmd_open = gripper_action_normalized < self.open_threshold
|
||||
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
@@ -409,11 +426,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
A dictionary containing the penalty value, max gripper position,
|
||||
and the open/closed thresholds.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
"open_threshold": self.open_threshold,
|
||||
"closed_threshold": self.closed_threshold,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -134,6 +134,24 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
|
||||
|
||||
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
|
||||
(already ``(C, 1, 1)``). Needed by RL training, which can start without
|
||||
a dataset and supplies stats manually via JSON config.
|
||||
"""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type != FeatureType.VISUAL:
|
||||
continue
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
|
||||
continue
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -152,6 +170,7 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -201,6 +220,7 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -211,6 +231,7 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
"""Processor step that turns raw language columns into rendered chat messages.
|
||||
|
||||
Reads ``language_persistent`` and ``language_events`` from the transition's
|
||||
complementary data, renders them through ``recipe`` at the sample timestamp,
|
||||
and replaces the raw columns with the resulting ``messages`` /
|
||||
``message_streams`` / ``target_message_indices`` keys.
|
||||
"""
|
||||
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
"""Render messages for a single transition; return ``None`` to drop it."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
sample_idx = complementary_data.get("index", 0)
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=_scalar(timestamp),
|
||||
sample_idx=int(_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
@@ -30,7 +30,7 @@ class RewardClassifierConfig(RewardModelConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
model_name: str = "lerobot/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
@@ -105,6 +105,7 @@ class Classifier(PreTrainedRewardModel):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
**kwargs,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
|
||||
+26
-16
@@ -12,23 +12,33 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Reinforcement learning modules.
|
||||
"""Reinforcement learning modules.
|
||||
|
||||
Requires: ``pip install 'lerobot[hilserl]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.rl.actor import ...
|
||||
from lerobot.rl.learner import ...
|
||||
from lerobot.rl.learner_service import ...
|
||||
from lerobot.rl.buffer import ...
|
||||
from lerobot.rl.eval_policy import ...
|
||||
from lerobot.rl.gym_manipulator import ...
|
||||
Distributed actor / learner entry points (``actor``, ``learner``,
|
||||
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
|
||||
buffer, data sources and trainer are gRPC-free and usable standalone.
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||
from .algorithms.factory import (
|
||||
make_algorithm as make_algorithm,
|
||||
make_algorithm_config as make_algorithm_config,
|
||||
)
|
||||
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .buffer import ReplayBuffer as ReplayBuffer
|
||||
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||
from .trainer import RLTrainer as RLTrainer
|
||||
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"make_algorithm",
|
||||
"make_algorithm_config",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTrainer",
|
||||
"ReplayBuffer",
|
||||
"DataMixer",
|
||||
"OnlineOfflineMixer",
|
||||
]
|
||||
|
||||
+113
-78
@@ -49,39 +49,53 @@ https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _grpc_available:
|
||||
import grpc
|
||||
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
grpc_channel_options,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
else:
|
||||
grpc = None
|
||||
services_pb2 = None
|
||||
services_pb2_grpc = None
|
||||
bytes_to_state_dict = None
|
||||
grpc_channel_options = None
|
||||
python_object_to_bytes = None
|
||||
receive_bytes_in_chunks = None
|
||||
send_bytes_in_chunks = None
|
||||
transitions_to_bytes = None
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
grpc_channel_options,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
@@ -89,19 +103,24 @@ from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .algorithms.base import RLAlgorithm
|
||||
from .algorithms.factory import make_algorithm
|
||||
from .gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
reset_and_build_transition,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from .queue import get_last_item_from_queue
|
||||
from .train_rl import TrainRLServerPipelineConfig
|
||||
|
||||
# Main entry point
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
@@ -212,7 +231,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
@@ -252,22 +271,24 @@ def act_with_policy(
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### To avoid sending a policy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
policy = policy.to(device).eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
# Build the algorithm
|
||||
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -291,8 +312,17 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
# Unnormalize only the continuous part.
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||
discrete_action = action[..., -1:].to(
|
||||
device=continuous_action.device, dtype=continuous_action.dtype
|
||||
)
|
||||
action = torch.cat([continuous_action, discrete_action], dim=-1)
|
||||
else:
|
||||
action = postprocessor.process_action(action)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
@@ -326,7 +356,8 @@ def act_with_policy(
|
||||
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
|
||||
if is_intervention:
|
||||
episode_intervention = True
|
||||
episode_intervention_steps += 1
|
||||
|
||||
@@ -334,6 +365,7 @@ def act_with_policy(
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
@@ -354,7 +386,7 @@ def act_with_policy(
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
@@ -390,14 +422,7 @@ def act_with_policy(
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
@@ -408,10 +433,10 @@ def act_with_policy(
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
stub: "services_pb2_grpc.LearnerServiceStub",
|
||||
shutdown_event: Any, # Event
|
||||
attempts: int = 30,
|
||||
):
|
||||
) -> bool:
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
@@ -441,12 +466,14 @@ def establish_learner_connection(
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
) -> "tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]":
|
||||
"""Return a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
|
||||
Returns:
|
||||
tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: The stub and the channel.
|
||||
"""
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
@@ -461,16 +488,18 @@ def learner_service_client(
|
||||
def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||
grpc_channel: "grpc.Channel | None" = None,
|
||||
) -> None:
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
|
||||
grpc_channel (grpc.Channel | None): Optional pre-created channel.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
@@ -513,12 +542,11 @@ def receive_policy(
|
||||
def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||
grpc_channel: "grpc.Channel | None" = None,
|
||||
) -> None:
|
||||
"""Send transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
@@ -526,6 +554,13 @@ def send_transitions(
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||
transitions_queue (Queue): The queue to receive the transitions.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
|
||||
grpc_channel (grpc.Channel | None): Optional pre-created channel.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
@@ -563,18 +598,24 @@ def send_transitions(
|
||||
def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: "services_pb2_grpc.LearnerServiceStub | None" = None,
|
||||
grpc_channel: "grpc.Channel | None" = None,
|
||||
) -> None:
|
||||
"""Send interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||
interactions_queue (Queue): The queue to receive the interactions.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
learner_client (services_pb2_grpc.LearnerServiceStub | None): Optional pre-created stub.
|
||||
grpc_channel (grpc.Channel | None): Optional pre-created channel.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
@@ -613,7 +654,11 @@ def send_interactions(
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
|
||||
def transitions_stream(
|
||||
shutdown_event: Any, # Event
|
||||
transitions_queue: Queue,
|
||||
timeout: float,
|
||||
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
@@ -629,10 +674,10 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout:
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event,
|
||||
shutdown_event: Any, # Event
|
||||
interactions_queue: Queue,
|
||||
timeout: float, # type: ignore
|
||||
) -> services_pb2.Empty:
|
||||
timeout: float,
|
||||
) -> "Generator[Any, None, services_pb2.Empty]":
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=timeout)
|
||||
@@ -652,7 +697,8 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
|
||||
"""Drain the latest learner-pushed weights into ``algorithm.policy``."""
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
@@ -667,18 +713,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
algorithm.load_weights(state_dicts, device=device)
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
__all__ = [
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
]
|
||||
@@ -0,0 +1,207 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.types import BatchType
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .configs import RLAlgorithmConfig, TrainingStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
|
||||
from ..data_sources.data_mixer import DataMixer
|
||||
|
||||
T = TypeVar("T", bound="RLAlgorithm")
|
||||
|
||||
|
||||
class RLAlgorithm(HubMixin, abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
config_class: type[RLAlgorithmConfig]
|
||||
name: str
|
||||
config: RLAlgorithmConfig
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One complete training step.
|
||||
|
||||
The algorithm calls ``next(batch_iterator)`` as many times as it
|
||||
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
) -> Iterator[BatchType]:
|
||||
"""Create the data iterator this algorithm needs.
|
||||
|
||||
The default implementation uses the standard ``data_mixer.get_iterator()``.
|
||||
Algorithms that need specialised sampling should override this method.
|
||||
"""
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""Build and return the optimizers used during training.
|
||||
|
||||
Called once on the learner side after construction.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def optimization_step(self) -> int:
|
||||
"""Current learner optimization step.
|
||||
|
||||
Part of the stable contract for checkpoint/resume. Algorithms can
|
||||
either use this default storage or override for custom behavior.
|
||||
"""
|
||||
return getattr(self, "_optimization_step", 0)
|
||||
|
||||
@optimization_step.setter
|
||||
def optimization_step(self, value: int) -> None:
|
||||
self._optimization_step = int(value)
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors."""
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Algorithm-owned trainable tensors.
|
||||
|
||||
Must return a flat tensor mapping for everything the algorithm owns
|
||||
that is not part of the policy (e.g. critic ensembles, target networks,
|
||||
temperature parameters). Algorithms with no training-only tensors
|
||||
should explicitly return an empty dict.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
device: str | torch.device = "cpu",
|
||||
) -> None:
|
||||
"""In-place load of algorithm-owned tensors.
|
||||
|
||||
Implementations MUST keep the identity of any ``nn.Parameter`` that an
|
||||
optimizer references (e.g. SAC's ``log_alpha``) by using ``.copy_()``
|
||||
rather than rebinding the attribute.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Persist the algorithm's tensors and config to ``save_directory``.
|
||||
|
||||
Writes ``model.safetensors`` (algorithm tensors via :meth:`state_dict`)
|
||||
and ``config.json`` (via :meth:`RLAlgorithmConfig.save_pretrained`).
|
||||
"""
|
||||
tensors = {k: v.detach().cpu().contiguous() for k, v in self.state_dict().items()}
|
||||
save_safetensors(tensors, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
self.config._save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
policy: nn.Module,
|
||||
config: RLAlgorithmConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
device: str | torch.device = "cpu",
|
||||
**algo_kwargs: Any,
|
||||
) -> T:
|
||||
"""Build an algorithm and load its weights from ``pretrained_name_or_path``."""
|
||||
if config is None:
|
||||
config = cls.config_class.from_pretrained(
|
||||
pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
if hasattr(config, "policy_config"):
|
||||
config.policy_config = policy.config
|
||||
|
||||
instance = cls(policy=policy, config=config, **algo_kwargs)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
if os.path.isdir(model_id):
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
tensors = load_safetensors(model_file)
|
||||
instance.load_state_dict(tensors, device=device)
|
||||
return instance
|
||||
@@ -0,0 +1,138 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
T = TypeVar("T", bound="RLAlgorithmConfig")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingStats:
|
||||
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
|
||||
|
||||
losses: dict[str, float] = field(default_factory=dict)
|
||||
grad_norms: dict[str, float] = field(default_factory=dict)
|
||||
extra: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_log_dict(self) -> dict[str, float]:
|
||||
"""Flatten all stats into a single dict for logging."""
|
||||
|
||||
d: dict[str, float] = {}
|
||||
for name, val in self.losses.items():
|
||||
d[name] = val
|
||||
for name, val in self.grad_norms.items():
|
||||
d[f"{name}_grad_norm"] = val
|
||||
for name, val in self.extra.items():
|
||||
d[name] = val
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAlgorithmConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""Registry for algorithm configs."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Registered name of this algorithm config (e.g. ``"sac"``)."""
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||
"""Build an algorithm config from a policy config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Serialize this config as ``config.json`` inside ``save_directory``."""
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**algo_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
with draccus.config_type("json"):
|
||||
instance = draccus.parse(RLAlgorithmConfig, config_file, args=[])
|
||||
|
||||
if cls is not RLAlgorithmConfig and not isinstance(instance, cls):
|
||||
raise TypeError(
|
||||
f"Config at {model_id} has type '{instance.type}' but was loaded via "
|
||||
f"{cls.__name__}; use the matching subclass or RLAlgorithmConfig.from_pretrained()."
|
||||
)
|
||||
|
||||
for key, value in algo_kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
return instance
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from .base import RLAlgorithm
|
||||
from .configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
||||
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
|
||||
|
||||
Args:
|
||||
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
|
||||
**kwargs: Keyword arguments forwarded to the config class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the matching ``RLAlgorithmConfig`` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``algorithm_type`` is not registered.
|
||||
"""
|
||||
try:
|
||||
cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
|
||||
except KeyError as err:
|
||||
raise ValueError(
|
||||
f"Algorithm type '{algorithm_type}' is not registered. "
|
||||
f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}"
|
||||
) from err
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def get_algorithm_class(name: str) -> type[RLAlgorithm]:
|
||||
"""
|
||||
Retrieves an RL algorithm class by its registered name.
|
||||
|
||||
This function uses dynamic imports to avoid loading all algorithm classes into
|
||||
memory at once, improving startup time and reducing dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the algorithm. Supported names are "sac".
|
||||
|
||||
Returns:
|
||||
The algorithm class corresponding to the given name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the algorithm name is not recognized.
|
||||
"""
|
||||
if name == "sac":
|
||||
from .sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm
|
||||
raise ValueError(
|
||||
f"Algorithm type '{name}' is not available. "
|
||||
f"Known: {list(RLAlgorithmConfig.get_known_choices().keys())}"
|
||||
)
|
||||
|
||||
|
||||
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
"""
|
||||
Instantiate an RL algorithm.
|
||||
|
||||
This factory function looks up the :class:`RLAlgorithm` subclass that matches
|
||||
``cfg.type`` and instantiates it with the provided policy. It also enforces
|
||||
that ``cfg.policy_config`` has been populated before construction (this is
|
||||
normally handled by :meth:`TrainRLServerPipelineConfig.validate`).
|
||||
|
||||
Args:
|
||||
cfg: The algorithm configuration. Must have ``policy_config`` set.
|
||||
policy: The policy module the algorithm will train.
|
||||
|
||||
Returns:
|
||||
An instantiated :class:`RLAlgorithm`.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``cfg.policy_config`` is ``None`` or ``cfg.type`` is not
|
||||
registered.
|
||||
"""
|
||||
if getattr(cfg, "policy_config", None) is None:
|
||||
raise ValueError(
|
||||
f"{type(cfg).__name__}.policy_config is None. "
|
||||
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
|
||||
"before calling make_algorithm()."
|
||||
)
|
||||
cls = get_algorithm_class(cfg.type)
|
||||
return cls(policy=policy, config=cfg)
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sac import SACAlgorithmConfig
|
||||
from .sac_algorithm import SACAlgorithm
|
||||
|
||||
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||
CriticNetworkConfig,
|
||||
GaussianActorConfig,
|
||||
)
|
||||
|
||||
from ..configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""Soft Actor-Critic (SAC) algorithm configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum
|
||||
entropy reinforcement learning framework. It learns a policy and a Q-function
|
||||
simultaneously using experience collected from the environment.
|
||||
|
||||
This configuration class contains the algorithm-side hyperparameters: critic
|
||||
ensemble, target networks, temperature / entropy tuning, and the Bellman
|
||||
update loop. The policy-side (actor + observation encoder) lives in
|
||||
:class:`~lerobot.policies.gaussian_actor.GaussianActorConfig` and is
|
||||
referenced via :attr:`policy_config`.
|
||||
"""
|
||||
|
||||
# Optimizer learning rates
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
|
||||
# Bellman update
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
|
||||
# Critic ensemble
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
# Temperature / entropy
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Target entropy for automatic temperature tuning. If ``None``, defaults to
|
||||
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
|
||||
# there is a discrete action head).
|
||||
target_entropy: float | None = None
|
||||
|
||||
# Update loop
|
||||
# Update-to-data ratio. Set to >1 to enable extra critic updates per env step.
|
||||
utd_ratio: int = 1
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Optimizations
|
||||
# torch.compile is currently disabled by default
|
||||
use_torch_compile: bool = False
|
||||
|
||||
# Policy config
|
||||
policy_config: PreTrainedConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||
"""Build an algorithm config with default hyperparameters for a given policy."""
|
||||
return cls(
|
||||
policy_config=policy_cfg,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
)
|
||||
@@ -0,0 +1,672 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
MLP,
|
||||
DiscreteCritic,
|
||||
GaussianActorObservationEncoder,
|
||||
GaussianActorPolicy,
|
||||
orthogonal_init,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.types import BatchType
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
from ..base import RLAlgorithm
|
||||
from ..configs import TrainingStats
|
||||
from .configuration_sac import SACAlgorithmConfig
|
||||
|
||||
|
||||
class SACAlgorithm(RLAlgorithm):
|
||||
"""Soft Actor-Critic. Owns critics, targets, temperature, and loss computation."""
|
||||
|
||||
config_class = SACAlgorithmConfig
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: GaussianActorPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.policy_config = config.policy_config
|
||||
self.policy = policy
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self._init_critics(action_dim)
|
||||
self._init_temperature(action_dim)
|
||||
|
||||
self._device = torch.device(self.policy.config.device)
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critics(self, action_dim) -> None:
|
||||
"""Build critic ensemble, targets."""
|
||||
encoder = self.policy.encoder_critic
|
||||
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
# TODO(Khalil): Investigate and fix torch.compile
|
||||
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.discrete_critic_target = None
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
|
||||
|
||||
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
|
||||
"""Build target discrete critic (main network is owned by the policy)."""
|
||||
discrete_critic_target = DiscreteCritic(
|
||||
encoder=encoder,
|
||||
input_dim=encoder.output_dim,
|
||||
output_dim=self.policy_config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||
return discrete_critic_target
|
||||
|
||||
def _init_temperature(self, continuous_action_dim: int) -> None:
|
||||
"""Set up temperature parameter (log_alpha) and target entropy."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
total_action_dim = continuous_action_dim + (
|
||||
1 if self.policy_config.num_discrete_actions is not None else 0
|
||||
)
|
||||
self.target_entropy = -total_action_dim / 2
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
self.policy.to(self._device)
|
||||
self.critic_ensemble.to(self._device)
|
||||
self.critic_target.to(self._device)
|
||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||
if self.discrete_critic_target is not None:
|
||||
self.discrete_critic_target.to(self._device)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def _critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def _discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Run one SAC training step (critic / discrete-critic / actor / temperature).
|
||||
|
||||
Pulls ``utd_ratio`` batches from ``batch_iterator``, computes the relevant
|
||||
losses, backpropagates each, and updates target networks.
|
||||
|
||||
Args:
|
||||
batch_iterator: yields batches each containing
|
||||
- ``action``: Action tensor
|
||||
- ``reward``: Reward tensor
|
||||
- ``state``: Observations tensor dict
|
||||
- ``next_state``: Next observations tensor dict
|
||||
- ``done``: Done mask tensor
|
||||
- ``observation_feature``: Optional pre-computed observation features
|
||||
- ``next_observation_feature``: Optional pre-computed next observation features
|
||||
- ``complementary_info`` (optional): per-step extras like discrete penalties
|
||||
|
||||
Returns:
|
||||
TrainingStats with per-component losses and grad norms.
|
||||
"""
|
||||
clip = self.config.grad_clip_norm
|
||||
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch, include_complementary_info=True)
|
||||
|
||||
loss_critic = self._compute_loss_critic(fb)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
|
||||
self.optimizers["discrete_critic"].step()
|
||||
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch, include_complementary_info=False)
|
||||
|
||||
loss_critic = self._compute_loss_critic(fb)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
stats = TrainingStats(
|
||||
losses={"loss_critic": loss_critic.item()},
|
||||
grad_norms={"critic": critic_grad},
|
||||
)
|
||||
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(), max_norm=clip
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
stats.losses["loss_discrete_critic"] = loss_dc.item()
|
||||
stats.grad_norms["discrete_critic"] = dc_grad
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
for _ in range(self.config.policy_update_freq):
|
||||
loss_actor = self._compute_loss_actor(fb)
|
||||
self.optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.actor.parameters(), max_norm=clip
|
||||
).item()
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
loss_temp = self._compute_loss_temperature(fb)
|
||||
self.optimizers["temperature"].zero_grad()
|
||||
loss_temp.backward()
|
||||
temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item()
|
||||
self.optimizers["temperature"].step()
|
||||
|
||||
stats.losses["loss_actor"] = loss_actor.item()
|
||||
stats.losses["loss_temperature"] = loss_temp.item()
|
||||
stats.grad_norms["actor"] = actor_grad
|
||||
stats.grad_norms["temperature"] = temp_grad
|
||||
stats.extra["temperature"] = self.temperature
|
||||
|
||||
self._update_target_networks()
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
# Extract common components from batch
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
observation_features = batch.get("observation_feature")
|
||||
# Extract critic-specific components
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.policy.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self._critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self._critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self._discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self._discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self._discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
actions_pi, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||
|
||||
q_preds = self._critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_p, p in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
|
||||
):
|
||||
target_p.data.copy_(
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.policy.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_p.data.copy_(
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def _prepare_forward_batch(
|
||||
self, batch: BatchType, *, include_complementary_info: bool = True
|
||||
) -> dict[str, Any]:
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
observation_features, next_observation_features = self.get_observation_features(
|
||||
observations, next_observations
|
||||
)
|
||||
forward_batch: dict[str, Any] = {
|
||||
ACTION: batch[ACTION],
|
||||
"reward": batch["reward"],
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": batch["done"],
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
if include_complementary_info and "complementary_info" in batch:
|
||||
forward_batch["complementary_info"] = batch["complementary_info"]
|
||||
return forward_batch
|
||||
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping component names ("actor", "critic", "temperature")
|
||||
to their respective Adam optimizers.
|
||||
"""
|
||||
actor_params = self.policy.get_optim_params()["actor"]
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Send actor + discrete-critic state dicts."""
|
||||
state_dicts: dict[str, Any] = {
|
||||
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
|
||||
}
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
self.policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
return state_dicts
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load actor + discrete-critic weights into the policy."""
|
||||
actor_sd = move_state_dict_to_device(weights["policy"], device=device)
|
||||
self.policy.actor.load_state_dict(actor_sd)
|
||||
if "discrete_critic" in weights and self.policy.discrete_critic is not None:
|
||||
discrete_sd = move_state_dict_to_device(weights["discrete_critic"], device=device)
|
||||
self.policy.discrete_critic.load_state_dict(discrete_sd)
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Algorithm-owned trainable tensors.
|
||||
|
||||
Encoder weights are stripped because they are owned by the policy
|
||||
(``policy.encoder_critic``) and already saved via ``policy.save_pretrained``.
|
||||
"""
|
||||
bundle: dict[str, torch.Tensor] = {}
|
||||
for k, v in _strip_encoder_keys(self.critic_ensemble.state_dict()).items():
|
||||
bundle[f"critic_ensemble.{k}"] = v
|
||||
for k, v in _strip_encoder_keys(self.critic_target.state_dict()).items():
|
||||
bundle[f"critic_target.{k}"] = v
|
||||
if self.discrete_critic_target is not None:
|
||||
for k, v in _strip_encoder_keys(self.discrete_critic_target.state_dict()).items():
|
||||
bundle[f"discrete_critic_target.{k}"] = v
|
||||
bundle["log_alpha"] = self.log_alpha.detach()
|
||||
return bundle
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
device: str | torch.device = "cpu",
|
||||
) -> None:
|
||||
"""In-place load of algorithm-owned tensors.
|
||||
|
||||
``log_alpha`` is restored via ``Parameter.data.copy_`` so the
|
||||
``temperature`` optimizer's reference to the parameter object stays
|
||||
valid after resume.
|
||||
"""
|
||||
critic_ensemble_state = _split_prefix(state_dict, "critic_ensemble.")
|
||||
critic_target_state = _split_prefix(state_dict, "critic_target.")
|
||||
self.critic_ensemble.load_state_dict(critic_ensemble_state, strict=False)
|
||||
self.critic_target.load_state_dict(critic_target_state, strict=False)
|
||||
|
||||
if self.discrete_critic_target is not None:
|
||||
discrete_target_state = _split_prefix(state_dict, "discrete_critic_target.")
|
||||
self.discrete_critic_target.load_state_dict(discrete_target_state, strict=False)
|
||||
|
||||
if "log_alpha" in state_dict:
|
||||
self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.log_alpha.device))
|
||||
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def _strip_encoder_keys(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Drop ``encoder.*`` keys from a critic-module state dict."""
|
||||
return {k: v for k, v in state.items() if not k.startswith("encoder.")}
|
||||
|
||||
|
||||
def _split_prefix(state: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
|
||||
"""Return the subset of ``state`` whose keys start with ``prefix``, prefix-stripped."""
|
||||
return {k.removeprefix(prefix): v for k, v in state.items() if k.startswith(prefix)}
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (GaussianActorObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
@@ -97,8 +97,8 @@ class ReplayBuffer:
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
state_keys (list[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Callable | None): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
|
||||
@@ -634,7 +634,7 @@ class ReplayBuffer:
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
transitions (list[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
if state_keys is None:
|
||||
|
||||
@@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
crop_params_dict (dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
resize_size (tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.types import BatchType
|
||||
|
||||
from .data_mixer import DataMixer, OnlineOfflineMixer
|
||||
|
||||
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from lerobot.types import BatchType
|
||||
|
||||
from ..buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
|
||||
|
||||
class DataMixer(abc.ABC):
|
||||
"""Abstract interface for all data mixing strategies."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
"""Draw one batch of ``batch_size`` transitions."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Infinite iterator that yields batches."""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
|
||||
|
||||
class OnlineOfflineMixer(DataMixer):
|
||||
"""Mixes transitions from an online and an offline replay buffer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer | None = None,
|
||||
online_ratio: float = 1.0,
|
||||
):
|
||||
if not 0.0 <= online_ratio <= 1.0:
|
||||
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
|
||||
self.online_buffer = online_buffer
|
||||
self.offline_buffer = offline_buffer
|
||||
self.online_ratio = online_ratio
|
||||
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
if self.offline_buffer is None:
|
||||
return self.online_buffer.sample(batch_size)
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
n_offline = batch_size - n_online
|
||||
|
||||
online_batch = self.online_buffer.sample(n_online)
|
||||
offline_batch = self.offline_buffer.sample(n_offline)
|
||||
return concatenate_batch_transitions(online_batch, offline_batch)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Yield batches by composing buffer async iterators."""
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
|
||||
online_iter = self.online_buffer.get_iterator(
|
||||
batch_size=n_online,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
if self.offline_buffer is None:
|
||||
yield from online_iter
|
||||
return
|
||||
|
||||
n_offline = batch_size - n_online
|
||||
offline_iter = self.offline_buffer.get_iterator(
|
||||
batch_size=n_offline,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
while True:
|
||||
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))
|
||||
@@ -17,7 +17,6 @@ import logging
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
@@ -31,6 +30,7 @@ from lerobot.teleoperators import (
|
||||
)
|
||||
|
||||
from .gym_manipulator import make_robot_env
|
||||
from .train_rl import TrainRLServerPipelineConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ from lerobot.teleoperators import (
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -312,6 +313,7 @@ def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
|
||||
# Check if this is a GymHIL simulation environment
|
||||
if cfg.name == "gym_hil":
|
||||
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
|
||||
require_package("gym-hil", extra="hilserl", import_name="gym_hil")
|
||||
import gym_hil # noqa: F401
|
||||
|
||||
# Extract gripper settings with defaults
|
||||
@@ -383,10 +385,21 @@ def make_processors(
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=device),
|
||||
]
|
||||
|
||||
# Add time limit processor if reset config exists
|
||||
if cfg.processor.reset is not None:
|
||||
env_pipeline_steps.append(
|
||||
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
|
||||
)
|
||||
|
||||
env_pipeline_steps.extend(
|
||||
[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=device),
|
||||
]
|
||||
)
|
||||
|
||||
return DataProcessorPipeline(
|
||||
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||
), DataProcessorPipeline(
|
||||
@@ -551,8 +564,19 @@ def step_env_and_process_transition(
|
||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||
|
||||
if hasattr(env, "get_raw_joint_positions"):
|
||||
raw_joint_positions = env.get_raw_joint_positions()
|
||||
if raw_joint_positions is not None:
|
||||
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||
|
||||
# Merge env and action-processor info: env wins for str keys, action-processor
|
||||
# wins for `TeleopEvents` enum keys
|
||||
action_info = processed_action_transition[TransitionKey.INFO]
|
||||
new_info = info.copy()
|
||||
new_info.update(processed_action_transition[TransitionKey.INFO])
|
||||
for key, value in action_info.items():
|
||||
if isinstance(key, TeleopEvents):
|
||||
new_info[key] = value
|
||||
|
||||
new_transition = create_transition(
|
||||
observation=obs,
|
||||
@@ -568,6 +592,24 @@ def step_env_and_process_transition(
|
||||
return new_transition
|
||||
|
||||
|
||||
def reset_and_build_transition(
|
||||
env: gym.Env,
|
||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
) -> EnvTransition:
|
||||
"""Reset env + processors and return the first env-processed transition."""
|
||||
obs, info = env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
complementary_data: dict[str, Any] = {}
|
||||
if hasattr(env, "get_raw_joint_positions"):
|
||||
raw_joint_positions = env.get_raw_joint_positions()
|
||||
if raw_joint_positions is not None:
|
||||
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
return env_processor(data=transition)
|
||||
|
||||
|
||||
def control_loop(
|
||||
env: gym.Env,
|
||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
@@ -593,17 +635,7 @@ def control_loop(
|
||||
print("- When not intervening, robot will stay still")
|
||||
print("- Press Ctrl+C to exit")
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = env.reset()
|
||||
complementary_data = (
|
||||
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
|
||||
)
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(data=transition)
|
||||
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||
|
||||
# Determine if gripper is used
|
||||
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
||||
@@ -659,79 +691,82 @@ def control_loop(
|
||||
episode_step = 0
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
while episode_idx < cfg.dataset.num_episodes_to_record:
|
||||
step_start_time = time.perf_counter()
|
||||
try:
|
||||
while episode_idx < cfg.dataset.num_episodes_to_record:
|
||||
step_start_time = time.perf_counter()
|
||||
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
action=neutral_action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if cfg.mode == "record":
|
||||
observations = {
|
||||
observation = {
|
||||
k: v.squeeze(0).cpu()
|
||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||
if isinstance(v, torch.Tensor)
|
||||
}
|
||||
# Use teleop_action if available, otherwise use the action from the transition
|
||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"teleop_action", transition[TransitionKey.ACTION]
|
||||
|
||||
transition = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
action=neutral_action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
ACTION: action_to_record.cpu(),
|
||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
DONE: np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
if use_gripper:
|
||||
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
|
||||
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if dataset is not None:
|
||||
frame["task"] = cfg.dataset.task
|
||||
dataset.add_frame(frame)
|
||||
if cfg.mode == "record":
|
||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"teleop_action", transition[TransitionKey.ACTION]
|
||||
)
|
||||
frame = {
|
||||
**observation,
|
||||
ACTION: action_to_record.cpu(),
|
||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
DONE: np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
if use_gripper:
|
||||
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"discrete_penalty", 0.0
|
||||
)
|
||||
frame["complementary_info.discrete_penalty"] = np.array(
|
||||
[discrete_penalty], dtype=np.float32
|
||||
)
|
||||
|
||||
episode_step += 1
|
||||
if dataset is not None:
|
||||
frame["task"] = cfg.dataset.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Handle episode termination
|
||||
if terminated or truncated:
|
||||
episode_time = time.perf_counter() - episode_start_time
|
||||
logging.info(
|
||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||
)
|
||||
episode_step = 0
|
||||
episode_idx += 1
|
||||
episode_step += 1
|
||||
|
||||
if dataset is not None:
|
||||
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
|
||||
logging.info(f"Re-recording episode {episode_idx}")
|
||||
dataset.clear_episode_buffer()
|
||||
episode_idx -= 1
|
||||
else:
|
||||
logging.info(f"Saving episode {episode_idx}")
|
||||
dataset.save_episode()
|
||||
# Handle episode termination
|
||||
if terminated or truncated:
|
||||
episode_time = time.perf_counter() - episode_start_time
|
||||
logging.info(
|
||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||
)
|
||||
episode_step = 0
|
||||
episode_idx += 1
|
||||
|
||||
# Reset for new episode
|
||||
obs, info = env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
if dataset is not None:
|
||||
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
|
||||
logging.info(f"Re-recording episode {episode_idx}")
|
||||
dataset.clear_episode_buffer()
|
||||
episode_idx -= 1
|
||||
else:
|
||||
logging.info(f"Saving episode {episode_idx}")
|
||||
dataset.save_episode()
|
||||
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
# Reset for new episode
|
||||
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||
|
||||
# Maintain fps timing
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
# Maintain fps timing
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
finally:
|
||||
if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None:
|
||||
logging.info("Waiting for image writer to finish...")
|
||||
dataset.writer.image_writer.stop()
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
|
||||
+123
-309
@@ -51,9 +51,21 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.utils.import_utils import _grpc_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _grpc_available:
|
||||
import grpc
|
||||
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
else:
|
||||
grpc = None
|
||||
services_pb2_grpc = None
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Queue
|
||||
@@ -68,14 +80,11 @@ from lerobot.common.train_utils import (
|
||||
)
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import make_policy, make_pre_post_processors
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
bytes_to_python_object,
|
||||
@@ -84,26 +93,35 @@ from lerobot.transport.utils import (
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ALGORITHM_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from .algorithms.base import RLAlgorithm
|
||||
from .algorithms.factory import make_algorithm
|
||||
from .buffer import ReplayBuffer
|
||||
from .data_sources import OnlineOfflineMixer
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||
from .train_rl import TrainRLServerPipelineConfig
|
||||
from .trainer import RLTrainer
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||
# Fail fast with a friendly error if the optional ``hilserl`` extra is missing.
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
@@ -179,7 +197,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
|
||||
def start_learner_threads(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
) -> None:
|
||||
"""
|
||||
Start the learner threads for training.
|
||||
@@ -253,7 +271,7 @@ def start_learner_threads(
|
||||
def add_actor_information_and_train(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
parameters_queue: Queue,
|
||||
@@ -266,8 +284,8 @@ def add_actor_information_and_train(
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Delegates training updates to an ``RLAlgorithm``.
|
||||
- Periodically pushes updated weights to actors.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
@@ -286,17 +304,13 @@ def add_actor_information_and_train(
|
||||
# of 7%
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
clip_grad_norm_value = cfg.policy.grad_clip_norm
|
||||
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
fps = cfg.env.fps
|
||||
log_freq = cfg.log_freq
|
||||
save_freq = cfg.save_freq
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
async_prefetch = cfg.policy.async_prefetch
|
||||
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
@@ -308,7 +322,7 @@ def add_actor_information_and_train(
|
||||
|
||||
logging.info("Initializing policy")
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
@@ -317,15 +331,17 @@ def add_actor_information_and_train(
|
||||
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
# Push initial policy weights to actors
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
@@ -338,21 +354,37 @@ def add_actor_information_and_train(
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
# DataMixer: online-only or online/offline 50-50 mix
|
||||
data_mixer = OnlineOfflineMixer(
|
||||
online_buffer=replay_buffer,
|
||||
offline_buffer=offline_replay_buffer,
|
||||
online_ratio=cfg.online_ratio,
|
||||
)
|
||||
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=batch_size,
|
||||
preprocessor=preprocessor,
|
||||
)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
optimizers = algorithm.get_optimizers()
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(
|
||||
cfg=cfg, optimizers=optimizers, algorithm=algorithm, device=device
|
||||
)
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message = None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
algorithm.optimization_step = optimization_step
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
dataset_repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
# Exit the training loop if shutdown is requested
|
||||
@@ -365,7 +397,6 @@ def add_actor_information_and_train(
|
||||
transition_queue=transition_queue,
|
||||
replay_buffer=replay_buffer,
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
device=device,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
@@ -382,180 +413,20 @@ def add_actor_information_and_train(
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
if online_iterator is None:
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Initialize training info dictionary
|
||||
training_infos = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"critic_grad_norm": critic_grad_norm,
|
||||
}
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Add discrete critic info to training info
|
||||
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
|
||||
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
|
||||
|
||||
# Actor and temperature optimization (at specified frequency)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
# Actor optimization
|
||||
actor_output = policy.forward(forward_batch, model="actor")
|
||||
loss_actor = actor_output["loss_actor"]
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
# Add actor info to training info
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
temperature_output = policy.forward(forward_batch, model="temperature")
|
||||
loss_temperature = temperature_output["loss_temperature"]
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
# Add temperature info to training info
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
|
||||
stats = trainer.training_step()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
training_infos = stats.to_log_dict()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
optimization_step = algorithm.optimization_step
|
||||
if optimization_step % log_freq == 0:
|
||||
training_infos["replay_buffer_size"] = len(replay_buffer)
|
||||
if offline_replay_buffer is not None:
|
||||
@@ -583,7 +454,6 @@ def add_actor_information_and_train(
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
@@ -597,9 +467,12 @@ def add_actor_information_and_train(
|
||||
policy=policy,
|
||||
optimizers=optimizers,
|
||||
replay_buffer=replay_buffer,
|
||||
algorithm=algorithm,
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
fps=fps,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -607,7 +480,7 @@ def start_learner(
|
||||
parameters_queue: Queue,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
):
|
||||
"""
|
||||
@@ -681,9 +554,12 @@ def save_training_checkpoint(
|
||||
policy: nn.Module,
|
||||
optimizers: dict[str, Optimizer],
|
||||
replay_buffer: ReplayBuffer,
|
||||
algorithm: RLAlgorithm | None = None,
|
||||
offline_replay_buffer: ReplayBuffer | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
fps: int = 30,
|
||||
preprocessor=None,
|
||||
postprocessor=None,
|
||||
) -> None:
|
||||
"""
|
||||
Save training checkpoint and associated data.
|
||||
@@ -707,6 +583,8 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: Optional offline replay buffer to save
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
fps: Frames per second for dataset
|
||||
preprocessor: Optional preprocessor pipeline to save
|
||||
postprocessor: Optional postprocessor pipeline to save
|
||||
"""
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
@@ -715,7 +593,7 @@ def save_training_checkpoint(
|
||||
# Create checkpoint directory
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step)
|
||||
|
||||
# Save checkpoint
|
||||
# Save policy artifacts (pretrained_model/) + Trainer scaffolding (training_state/).
|
||||
save_checkpoint(
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
step=optimization_step,
|
||||
@@ -723,13 +601,22 @@ def save_training_checkpoint(
|
||||
policy=policy,
|
||||
optimizer=optimizers,
|
||||
scheduler=None,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
# Save interaction step manually
|
||||
training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR)
|
||||
os.makedirs(training_state_dir, exist_ok=True)
|
||||
training_state = {"step": optimization_step, "interaction_step": interaction_step}
|
||||
torch.save(training_state, os.path.join(training_state_dir, "training_state.pt"))
|
||||
# Algorithm-owned tensors live in their own component subfolder
|
||||
# so they can be `push_to_hub`'d independently and don't bloat the inference artifact.
|
||||
if algorithm is not None:
|
||||
algorithm.save_pretrained(checkpoint_dir / ALGORITHM_DIR)
|
||||
|
||||
# Enrich training_step.json with the RL-specific interaction_step counter so
|
||||
# both can be restored from a single file.
|
||||
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
write_json(
|
||||
{"step": optimization_step, "interaction_step": interaction_step},
|
||||
training_state_dir / TRAINING_STEP,
|
||||
)
|
||||
|
||||
# Update the "last" symlink
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
@@ -760,58 +647,6 @@ def save_training_checkpoint(
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_discrete_critic = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizers["discrete_critic"] = optimizer_discrete_critic
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
# Training setup functions
|
||||
|
||||
|
||||
@@ -875,13 +710,20 @@ def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipeli
|
||||
def load_training_state(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
optimizers: Optimizer | dict[str, Optimizer],
|
||||
algorithm: RLAlgorithm | None = None,
|
||||
device: str | torch.device = "cpu",
|
||||
):
|
||||
"""
|
||||
Loads the training state (optimizers, step count, etc.) from a checkpoint.
|
||||
Loads the training state (optimizers, RNG, step + interaction step, and
|
||||
algorithm-owned tensors) from the most recent checkpoint.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): Training configuration
|
||||
optimizers (Optimizer | dict): Optimizers to load state into
|
||||
cfg: Training configuration.
|
||||
optimizers: Optimizers to load state into.
|
||||
algorithm: Algorithm whose state dict should be restored.
|
||||
Required for full main-equivalent resume;
|
||||
the policy itself is restored separately via ``make_policy``.
|
||||
device: Device on which to place loaded algorithm tensors.
|
||||
|
||||
Returns:
|
||||
tuple: (optimization_step, interaction_step) or (None, None) if not resuming
|
||||
@@ -890,20 +732,31 @@ def load_training_state(
|
||||
return None, None
|
||||
|
||||
# Construct path to the last checkpoint directory
|
||||
checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK)
|
||||
checkpoint_dir = Path(cfg.output_dir) / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
|
||||
|
||||
logging.info(f"Loading training state from {checkpoint_dir}")
|
||||
|
||||
try:
|
||||
# Use the utility function from train_utils which loads the optimizer state
|
||||
step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None)
|
||||
# Restore optimizers + RNG + step from the standard `training_state/` folder
|
||||
step, optimizers, _ = utils_load_training_state(checkpoint_dir, optimizers, None)
|
||||
|
||||
# Load interaction step separately from training_state.pt
|
||||
training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt")
|
||||
interaction_step = 0
|
||||
if os.path.exists(training_state_path):
|
||||
training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load
|
||||
interaction_step = training_state.get("interaction_step", 0)
|
||||
# Restore algorithm-owned tensors
|
||||
if algorithm is not None:
|
||||
algo_dir = checkpoint_dir / ALGORITHM_DIR
|
||||
if algo_dir.is_dir():
|
||||
tensors = load_safetensors(str(algo_dir / SAFETENSORS_SINGLE_FILE))
|
||||
algorithm.load_state_dict(tensors, device=device)
|
||||
logging.info(f"Loaded algorithm state from {algo_dir}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"No algorithm state found at {algo_dir}; "
|
||||
"will keep their freshly-initialised values. Adam moments restored from the "
|
||||
"old optimizer state may not match these reset parameters."
|
||||
)
|
||||
|
||||
# Read interaction_step from the enriched training_step.json
|
||||
training_step_path = checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP
|
||||
interaction_step = int(load_json(training_step_path).get("interaction_step", 0))
|
||||
|
||||
logging.info(f"Resuming from step {step}, interaction step {interaction_step}")
|
||||
return step, interaction_step
|
||||
@@ -1016,33 +869,6 @@ def initialize_offline_replay_buffer(
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.learner == "threads"
|
||||
|
||||
@@ -1093,19 +919,11 @@ def check_nan_in_transition(
|
||||
return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None:
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
|
||||
# Create a dictionary to hold all the state dicts
|
||||
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
|
||||
|
||||
# Add discrete critic if it exists
|
||||
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
logging.debug("[LEARNER] Including discrete critic in state dict push")
|
||||
|
||||
state_dicts = algorithm.get_weights()
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
@@ -1129,9 +947,8 @@ def process_transitions(
|
||||
transition_queue: Queue,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
device: str,
|
||||
dataset_repo_id: str | None,
|
||||
shutdown_event: any,
|
||||
shutdown_event: Any, # Event
|
||||
):
|
||||
"""Process all available transitions from the queue.
|
||||
|
||||
@@ -1139,7 +956,6 @@ def process_transitions(
|
||||
transition_queue: Queue for receiving transitions from the actor
|
||||
replay_buffer: Replay buffer to add transitions to
|
||||
offline_replay_buffer: Offline replay buffer to add transitions to
|
||||
device: Device to move transitions to
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
shutdown_event: Event to signal shutdown
|
||||
"""
|
||||
@@ -1148,8 +964,6 @@ def process_transitions(
|
||||
transition_list = bytes_to_transitions(buffer=transition_list)
|
||||
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition=transition, device=device)
|
||||
|
||||
# Skip transitions with NaN values
|
||||
if check_nan_in_transition(
|
||||
observations=transition["state"],
|
||||
@@ -1163,7 +977,7 @@ def process_transitions(
|
||||
|
||||
# Add to offline buffer if it's an intervention
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
TeleopEvents.IS_INTERVENTION
|
||||
TeleopEvents.IS_INTERVENTION.value
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
@@ -1172,7 +986,7 @@ def process_interaction_messages(
|
||||
interaction_message_queue: Queue,
|
||||
interaction_step_shift: int,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any,
|
||||
shutdown_event: Any, # Event
|
||||
) -> dict | None:
|
||||
"""Process all available interaction messages from the queue.
|
||||
|
||||
|
||||
@@ -18,17 +18,32 @@
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import Event, Queue
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import _grpc_available
|
||||
|
||||
from .queue import get_last_item_from_queue
|
||||
|
||||
if TYPE_CHECKING or _grpc_available:
|
||||
import grpc
|
||||
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
|
||||
_ServicerBase = services_pb2_grpc.LearnerServiceServicer
|
||||
else:
|
||||
grpc = None
|
||||
services_pb2 = None
|
||||
services_pb2_grpc = None
|
||||
receive_bytes_in_chunks = None
|
||||
send_bytes_in_chunks = None
|
||||
_ServicerBase = object
|
||||
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
class LearnerService(_ServicerBase):
|
||||
"""
|
||||
Implementation of the LearnerService gRPC service
|
||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||
@@ -51,7 +66,9 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
self.queue_get_timeout = queue_get_timeout
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
def StreamParameters( # noqa: N802
|
||||
self, request: "services_pb2.Empty", context: "grpc.ServicerContext"
|
||||
):
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
@@ -86,7 +103,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
def SendTransitions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
@@ -100,7 +117,7 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
def SendInteractions(self, request_iterator, _context: "grpc.ServicerContext"): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
@@ -114,5 +131,5 @@ class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
def Ready(self, request: "services_pb2.Empty", context: "grpc.ServicerContext"): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Top-level pipeline config for distributed RL training (actor / learner)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
from .algorithms.configs import RLAlgorithmConfig
|
||||
from .algorithms.factory import make_algorithm_config
|
||||
from .algorithms.sac import SACAlgorithmConfig # noqa: F401
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm config.
|
||||
algorithm: RLAlgorithmConfig | None = None
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline".
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer.
|
||||
online_ratio: float = 0.5
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
|
||||
if self.algorithm is None:
|
||||
self.algorithm = make_algorithm_config("sac")
|
||||
|
||||
if getattr(self.algorithm, "policy_config", None) is None:
|
||||
self.algorithm.policy_config = self.policy
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from lerobot.types import BatchType
|
||||
|
||||
from .algorithms.base import RLAlgorithm
|
||||
from .algorithms.configs import TrainingStats
|
||||
from .data_sources.data_mixer import DataMixer
|
||||
|
||||
|
||||
class RLTrainer:
|
||||
"""Unified training step orchestrator.
|
||||
|
||||
Holds the algorithm, a DataMixer, and an optional preprocessor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: RLAlgorithm,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
preprocessor: Any | None = None,
|
||||
):
|
||||
self.algorithm = algorithm
|
||||
self.data_mixer = data_mixer
|
||||
self.batch_size = batch_size
|
||||
self._preprocessor = preprocessor
|
||||
|
||||
self._iterator: Iterator[BatchType] | None = None
|
||||
|
||||
self.algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
def _build_data_iterator(self) -> Iterator[BatchType]:
|
||||
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
|
||||
raw = self.algorithm.configure_data_iterator(
|
||||
data_mixer=self.data_mixer,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if self._preprocessor is not None:
|
||||
return _PreprocessedIterator(raw, self._preprocessor)
|
||||
return raw
|
||||
|
||||
def reset_data_iterator(self) -> None:
|
||||
"""Discard the current iterator so it will be rebuilt lazily next step."""
|
||||
self._iterator = None
|
||||
|
||||
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
|
||||
"""Swap the active data mixer, optionally resetting the iterator."""
|
||||
self.data_mixer = data_mixer
|
||||
if reset:
|
||||
self.reset_data_iterator()
|
||||
|
||||
def training_step(self) -> TrainingStats:
|
||||
"""Run one training step (algorithm-agnostic)."""
|
||||
if self._iterator is None:
|
||||
self._iterator = self._build_data_iterator()
|
||||
return self.algorithm.update(self._iterator)
|
||||
|
||||
|
||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
|
||||
"""Apply policy preprocessing to RL observations only."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
batch["state"] = preprocessor.process_observation(observations)
|
||||
batch["next_state"] = preprocessor.process_observation(next_observations)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class _PreprocessedIterator:
|
||||
"""Iterator wrapper that preprocesses each sampled RL batch."""
|
||||
|
||||
__slots__ = ("_raw", "_preprocessor")
|
||||
|
||||
def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None:
|
||||
self._raw = raw_iterator
|
||||
self._preprocessor = preprocessor
|
||||
|
||||
def __iter__(self) -> _PreprocessedIterator:
|
||||
return self
|
||||
|
||||
def __next__(self) -> BatchType:
|
||||
batch = next(self._raw)
|
||||
return preprocess_rl_batch(self._preprocessor, batch)
|
||||
@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
|
||||
clip_min: The minimum allowed gripper joint position.
|
||||
clip_max: The maximum allowed gripper joint position.
|
||||
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
|
||||
discrete_gripper: If True, interpret the input as a discrete class index
|
||||
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
|
||||
"""
|
||||
|
||||
speed_factor: float = 20.0
|
||||
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_vel = (gripper_vel - 1) * self.clip_max
|
||||
# Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
|
||||
# Negation accounts for SO100 sign (joint position increases on close).
|
||||
# 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
|
||||
gripper_vel = -(gripper_vel - 1) * self.clip_max
|
||||
|
||||
# Compute desired gripper position
|
||||
delta = gripper_vel * float(self.speed_factor)
|
||||
|
||||
@@ -92,6 +92,7 @@ def get_sys_info() -> dict[str, str]:
|
||||
info.update(
|
||||
{
|
||||
"PyTorch version": torch_version,
|
||||
"Torchcodec version": get_package_version("torchcodec"),
|
||||
"Is PyTorch built with CUDA support?": str(torch_cuda_available),
|
||||
"Cuda version": cuda_version,
|
||||
"GPU model": gpu_model,
|
||||
|
||||
@@ -48,7 +48,6 @@ from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -402,10 +401,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
# Only swap in the language-aware collate when the dataset actually
|
||||
# declares language columns; otherwise stay on PyTorch's default
|
||||
# collate so non-language training runs are unaffected.
|
||||
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
@@ -414,7 +409,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
@@ -104,11 +104,14 @@ class KeyboardTeleop(Teleoperator):
|
||||
|
||||
def _on_press(self, key):
|
||||
if hasattr(key, "char"):
|
||||
self.event_queue.put((key.char, True))
|
||||
key = key.char
|
||||
self.event_queue.put((key, True))
|
||||
|
||||
def _on_release(self, key):
|
||||
if hasattr(key, "char"):
|
||||
self.event_queue.put((key.char, False))
|
||||
key = key.char
|
||||
self.event_queue.put((key, False))
|
||||
|
||||
if key == keyboard.Key.esc:
|
||||
logging.info("ESC pressed, disconnecting.")
|
||||
self.disconnect()
|
||||
@@ -204,8 +207,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
# this is useful for retrieving other events like interventions for RL, episode success, etc.
|
||||
self.misc_keys_queue.put(key)
|
||||
|
||||
self.current_pressed.clear()
|
||||
|
||||
action_dict = {
|
||||
"delta_x": delta_x,
|
||||
"delta_y": delta_y,
|
||||
@@ -256,6 +257,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
]
|
||||
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
|
||||
|
||||
self.current_pressed.clear()
|
||||
|
||||
# Check for episode control commands from misc_keys_queue
|
||||
terminate_episode = False
|
||||
success = False
|
||||
|
||||
@@ -39,8 +39,8 @@ For more details, see the [Physical Intelligence π₀ blog post](https://www.ph
|
||||
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
|
||||
{% elif model_name == "sac" %}
|
||||
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
|
||||
{% elif model_name == "gaussian_actor" %}
|
||||
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
|
||||
{% elif model_name == "reward_classifier" %}
|
||||
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
|
||||
{% else %}
|
||||
|
||||
@@ -40,6 +40,7 @@ PolicyAction = torch.Tensor
|
||||
RobotAction = dict[str, Any]
|
||||
EnvAction = np.ndarray
|
||||
RobotObservation = dict[str, Any]
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
||||
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
"""Collate function that preserves Python-list and language fields as lists.
|
||||
|
||||
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
|
||||
rendered-message and language fields as plain Python lists, and delegates
|
||||
every other key to PyTorch's ``default_collate``.
|
||||
"""
|
||||
batch = [sample for sample in batch if sample is not None]
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
|
||||
# carry `messages` and half don't) is a real bug in the upstream
|
||||
# rendering step — silently filtering would hand downstream consumers a
|
||||
# preserved list shorter than the tensor batch. Raise instead so the
|
||||
# mismatch surfaces at the boundary.
|
||||
preserved: dict[str, list[Any]] = {}
|
||||
for key in _PYTHON_LIST_KEYS:
|
||||
presence = [key in sample for sample in batch]
|
||||
if not any(presence):
|
||||
continue
|
||||
if not all(presence):
|
||||
raise ValueError(
|
||||
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
|
||||
f"every sample in a batch must agree."
|
||||
)
|
||||
preserved[key] = [sample[key] for sample in batch]
|
||||
tensorizable = [
|
||||
{
|
||||
key: value
|
||||
for key, value in sample.items()
|
||||
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
|
||||
}
|
||||
for sample in batch
|
||||
]
|
||||
collated = default_collate(tensorizable)
|
||||
collated.update(preserved)
|
||||
return collated
|
||||
@@ -47,6 +47,7 @@ CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
PRETRAINED_MODEL_DIR = "pretrained_model"
|
||||
TRAINING_STATE_DIR = "training_state"
|
||||
ALGORITHM_DIR = "algorithm"
|
||||
RNG_STATE = "rng_state.safetensors"
|
||||
TRAINING_STEP = "training_step.json"
|
||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
|
||||
@@ -132,6 +132,7 @@ _faker_available = is_package_available("faker")
|
||||
_pynput_available = is_package_available("pynput")
|
||||
_pygame_available = is_package_available("pygame")
|
||||
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
|
||||
_grpc_available = is_package_available("grpcio", import_name="grpc")
|
||||
_wallx_deps_available = (
|
||||
_transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
|
||||
)
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
|
||||
|
||||
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
|
||||
return MessageTurn(role="user", content=content, stream="high_level")
|
||||
|
||||
|
||||
def _minimal_target_turn() -> MessageTurn:
|
||||
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
|
||||
|
||||
|
||||
# ── Message-recipe validation ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_message_recipe_validates_unknown_binding():
|
||||
with pytest.raises(ValueError, match="unknown binding"):
|
||||
TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${missing}", stream="high_level"),
|
||||
_minimal_target_turn(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_message_turn_requires_a_stream():
|
||||
"""Every turn must declare a stream — None is rejected at construction.
|
||||
|
||||
Previously this only failed at render time (``_validate_rendered``);
|
||||
catching it here means a malformed recipe YAML errors at load instead
|
||||
of at the first training sample.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="missing a stream"):
|
||||
MessageTurn(role="user", content="${task}")
|
||||
|
||||
|
||||
def test_message_recipe_requires_at_least_one_target():
|
||||
with pytest.raises(ValueError, match="target"):
|
||||
TrainingRecipe(
|
||||
messages=[
|
||||
_minimal_message_turn(),
|
||||
MessageTurn(role="assistant", content="no target", stream="high_level"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_recipe_rejects_both_messages_and_blend():
|
||||
with pytest.raises(ValueError, match="only one"):
|
||||
TrainingRecipe(
|
||||
messages=[_minimal_message_turn(), _minimal_target_turn()],
|
||||
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
|
||||
)
|
||||
|
||||
|
||||
def test_recipe_rejects_neither_messages_nor_blend():
|
||||
with pytest.raises(ValueError, match="must set one"):
|
||||
TrainingRecipe()
|
||||
|
||||
|
||||
# ── Blend validation ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_blend_must_be_non_empty():
|
||||
with pytest.raises(ValueError, match="at least one component"):
|
||||
TrainingRecipe(blend={})
|
||||
|
||||
|
||||
def test_blend_component_must_define_weight():
|
||||
with pytest.raises(ValueError, match="weight"):
|
||||
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
|
||||
|
||||
|
||||
def test_blend_component_weight_must_be_positive():
|
||||
with pytest.raises(ValueError, match="positive weight"):
|
||||
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
|
||||
|
||||
|
||||
def test_blend_component_must_define_messages():
|
||||
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
|
||||
# going through __post_init__ to exercise the blend-level validator.
|
||||
bad = TrainingRecipe.__new__(TrainingRecipe)
|
||||
bad.messages = None
|
||||
bad.bindings = None
|
||||
bad.blend = None
|
||||
bad.weight = 1.0
|
||||
with pytest.raises(ValueError, match="must define messages"):
|
||||
TrainingRecipe(blend={"a": bad})
|
||||
|
||||
|
||||
def test_blend_components_cannot_themselves_define_a_blend():
|
||||
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
|
||||
# Force-bypass the inner component's normal validation so the test
|
||||
# exercises the outer blend's "no nested blends" rule directly.
|
||||
nested = TrainingRecipe.__new__(TrainingRecipe)
|
||||
nested.messages = None
|
||||
nested.bindings = None
|
||||
nested.blend = inner.blend
|
||||
nested.weight = 1.0
|
||||
with pytest.raises(ValueError, match="cannot itself define a blend"):
|
||||
TrainingRecipe(blend={"outer": nested})
|
||||
|
||||
|
||||
# ── from_dict / from_yaml round-trips ────────────────────────────────
|
||||
|
||||
|
||||
def test_from_dict_with_nested_blend():
|
||||
recipe = TrainingRecipe.from_dict(
|
||||
{
|
||||
"blend": {
|
||||
"a": {
|
||||
"weight": 1.0,
|
||||
"messages": [
|
||||
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
|
||||
],
|
||||
},
|
||||
"b": {
|
||||
"weight": 2.0,
|
||||
"messages": [
|
||||
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert recipe.blend is not None
|
||||
assert set(recipe.blend) == {"a", "b"}
|
||||
assert recipe.blend["b"].weight == 2.0
|
||||
# Inner messages were promoted to MessageTurn instances.
|
||||
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
|
||||
|
||||
|
||||
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
|
||||
yaml_text = dedent(
|
||||
"""
|
||||
bindings:
|
||||
custom: "active_at(t, style=subtask)"
|
||||
messages:
|
||||
- {role: user, content: "${task}: ${custom}", stream: high_level}
|
||||
- {role: assistant, content: "ok", stream: high_level, target: true}
|
||||
"""
|
||||
).strip()
|
||||
path = tmp_path / "recipe.yaml"
|
||||
path.write_text(yaml_text)
|
||||
|
||||
via_classmethod = TrainingRecipe.from_yaml(path)
|
||||
via_helper = load_recipe(path)
|
||||
|
||||
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
|
||||
assert via_classmethod.messages[1].target is True
|
||||
# ``load_recipe`` is just a wrapper, but assert the two paths agree
|
||||
# on the structural result so a future divergence is caught here.
|
||||
assert via_helper.bindings == via_classmethod.bindings
|
||||
assert len(via_helper.messages) == len(via_classmethod.messages)
|
||||
|
||||
|
||||
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
|
||||
path = tmp_path / "bad.yaml"
|
||||
path.write_text("- just\n- a\n- list\n")
|
||||
with pytest.raises(ValueError, match="mapping at the top level"):
|
||||
TrainingRecipe.from_yaml(path)
|
||||
@@ -385,84 +385,3 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
|
||||
assert episodes_dir.exists()
|
||||
parquet_files = list(episodes_dir.rglob("*.parquet"))
|
||||
assert len(parquet_files) > 0
|
||||
|
||||
|
||||
# ── Tools accessor ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
|
||||
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||
|
||||
root = tmp_path / "no_tools"
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_tools",
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=root,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
assert meta.tools == DEFAULT_TOOLS
|
||||
# info.json on disk should NOT include a `tools` key for clean datasets
|
||||
with open(root / INFO_PATH) as f:
|
||||
info_on_disk = json.load(f)
|
||||
assert "tools" not in info_on_disk
|
||||
|
||||
|
||||
def test_tools_reads_declared_tools_from_info_json(tmp_path):
|
||||
"""A `tools` list written into info.json survives load → meta.tools.
|
||||
|
||||
Regression test for the bug where ``DatasetInfo.from_dict`` silently
|
||||
dropped the ``tools`` key (no matching dataclass field), so
|
||||
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
|
||||
what was on disk.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info
|
||||
|
||||
root = tmp_path / "with_tools"
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/with_tools",
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=root,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
custom_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a still image.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"label": {"type": "string"}},
|
||||
"required": ["label"],
|
||||
},
|
||||
},
|
||||
}
|
||||
info_path = root / INFO_PATH
|
||||
with open(info_path) as f:
|
||||
raw = json.load(f)
|
||||
raw["tools"] = [custom_tool]
|
||||
with open(info_path, "w") as f:
|
||||
json.dump(raw, f)
|
||||
|
||||
# Reload info from disk and rebind it on the metadata object
|
||||
meta.info = load_info(root)
|
||||
assert meta.tools == [custom_tool]
|
||||
|
||||
|
||||
def test_tools_round_trip_through_dataset_info(tmp_path):
|
||||
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
|
||||
from lerobot.datasets.utils import DatasetInfo
|
||||
|
||||
raw = {
|
||||
"codebase_version": "v3.1",
|
||||
"fps": 30,
|
||||
"features": SIMPLE_FEATURES,
|
||||
"tools": [{"type": "function", "function": {"name": "say"}}],
|
||||
}
|
||||
info = DatasetInfo.from_dict(raw)
|
||||
assert info.tools == raw["tools"]
|
||||
assert info.to_dict()["tools"] == raw["tools"]
|
||||
|
||||
@@ -1691,3 +1691,68 @@ def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_d
|
||||
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
||||
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
||||
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
||||
|
||||
|
||||
def test_episode_filter_filters_dataset(tmp_path, lerobot_dataset_factory):
|
||||
"""episode_filter on LeRobotDataset narrows the loaded dataset to matching episodes."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
threshold = sorted(lengths)[len(lengths) // 2]
|
||||
expected_eps = [i for i, length in enumerate(lengths) if length >= threshold]
|
||||
expected_frames = sum(lengths[i] for i in expected_eps)
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||
)
|
||||
|
||||
assert filtered.num_episodes == len(expected_eps)
|
||||
assert filtered.num_frames == expected_frames
|
||||
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||
assert seen_eps == set(expected_eps)
|
||||
|
||||
|
||||
def test_episode_filter_intersects_with_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""When both episodes and episode_filter are given to LeRobotDataset, the result is their intersection."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=8, total_frames=200)
|
||||
lengths = dataset.meta.episodes["length"]
|
||||
candidates = [0, 2, 4, 6]
|
||||
candidate_lengths = [lengths[i] for i in candidates]
|
||||
threshold = sorted(candidate_lengths)[len(candidate_lengths) // 2]
|
||||
expected_eps = [i for i in candidates if lengths[i] >= threshold]
|
||||
|
||||
filtered = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=candidates,
|
||||
episode_filter=lambda ep: ep["length"] >= threshold,
|
||||
)
|
||||
|
||||
assert filtered.num_episodes == len(expected_eps)
|
||||
seen_eps = {filtered[i]["episode_index"].item() for i in range(len(filtered))}
|
||||
assert seen_eps == set(expected_eps)
|
||||
|
||||
|
||||
def test_episode_filter_no_match_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""An empty match in LeRobotDataset's episode_filter raises a ValueError rather than silently returning an empty dataset."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||
|
||||
with pytest.raises(ValueError, match=r"The episode filter did not match any episode"):
|
||||
LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["length"] < 0,
|
||||
)
|
||||
|
||||
|
||||
def test_episode_filter_unknown_key_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""A predicate referencing a column absent from meta.episodes surfaces a clear KeyError."""
|
||||
dataset = lerobot_dataset_factory(root=tmp_path / "test", total_episodes=4, total_frames=100)
|
||||
|
||||
with pytest.raises(KeyError, match="not_a_real_field"):
|
||||
LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episode_filter=lambda ep: ep["not_a_real_field"] > 0,
|
||||
)
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import numpy as np # noqa: E402
|
||||
import pandas as pd # noqa: E402
|
||||
import pyarrow as pa # noqa: E402
|
||||
|
||||
from lerobot.datasets import LeRobotDataset # noqa: E402
|
||||
from lerobot.datasets.io_utils import write_info # noqa: E402
|
||||
from lerobot.datasets.language import ( # noqa: E402
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
VIEW_DEPENDENT_STYLES,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
language_events_arrow_type,
|
||||
language_feature_info,
|
||||
language_persistent_arrow_type,
|
||||
validate_camera_field,
|
||||
)
|
||||
from lerobot.datasets.utils import DEFAULT_DATA_PATH # noqa: E402
|
||||
|
||||
|
||||
def test_language_arrow_schema_has_expected_fields():
|
||||
persistent_row_type = language_persistent_arrow_type().value_type
|
||||
event_row_type = language_events_arrow_type().value_type
|
||||
|
||||
assert isinstance(persistent_row_type, pa.StructType)
|
||||
assert persistent_row_type.names == [
|
||||
"role",
|
||||
"content",
|
||||
"style",
|
||||
"timestamp",
|
||||
"camera",
|
||||
"tool_calls",
|
||||
]
|
||||
|
||||
assert isinstance(event_row_type, pa.StructType)
|
||||
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
|
||||
|
||||
|
||||
def test_style_registry_routes_columns():
|
||||
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
|
||||
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
|
||||
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
|
||||
|
||||
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("plan") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("memory") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("motion") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("interjection") == LANGUAGE_EVENTS
|
||||
assert column_for_style("vqa") == LANGUAGE_EVENTS
|
||||
assert column_for_style("trace") == LANGUAGE_EVENTS
|
||||
assert column_for_style(None) == LANGUAGE_EVENTS
|
||||
|
||||
|
||||
def test_view_dependent_styles():
|
||||
# motion lives in PERSISTENT_STYLES and is described in robot-frame
|
||||
# (joint / Cartesian) terms, so it is NOT view-dependent. Only vqa
|
||||
# (event) and trace (event, pixel-trajectory) carry a camera tag.
|
||||
assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
|
||||
assert is_view_dependent_style("vqa")
|
||||
assert is_view_dependent_style("trace")
|
||||
assert not is_view_dependent_style("motion")
|
||||
assert not is_view_dependent_style("subtask")
|
||||
assert not is_view_dependent_style("plan")
|
||||
assert not is_view_dependent_style("interjection")
|
||||
assert not is_view_dependent_style(None)
|
||||
|
||||
|
||||
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
|
||||
validate_camera_field("vqa", "observation.images.top")
|
||||
validate_camera_field("trace", "observation.images.front")
|
||||
with pytest.raises(ValueError, match="view-dependent"):
|
||||
validate_camera_field("vqa", None)
|
||||
with pytest.raises(ValueError, match="view-dependent"):
|
||||
validate_camera_field("trace", "")
|
||||
|
||||
|
||||
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
|
||||
validate_camera_field("subtask", None)
|
||||
validate_camera_field("plan", None)
|
||||
validate_camera_field("memory", None)
|
||||
validate_camera_field("motion", None)
|
||||
validate_camera_field("interjection", None)
|
||||
validate_camera_field(None, None)
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field("subtask", "observation.images.top")
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field("motion", "observation.images.top")
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field("interjection", "observation.images.top")
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field(None, "observation.images.top")
|
||||
|
||||
|
||||
def test_unknown_style_rejected():
|
||||
with pytest.raises(ValueError, match="Unknown language style"):
|
||||
column_for_style("surprise")
|
||||
|
||||
|
||||
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
|
||||
root = tmp_path / "language_dataset"
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=root,
|
||||
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
|
||||
use_videos=False,
|
||||
)
|
||||
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
|
||||
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
persistent = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "reach for the cup",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
}
|
||||
]
|
||||
event = {
|
||||
"role": "user",
|
||||
"content": "what is visible?",
|
||||
"style": "vqa",
|
||||
"camera": "observation.images.top",
|
||||
"tool_calls": None,
|
||||
}
|
||||
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
df = pd.read_parquet(data_path)
|
||||
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
|
||||
df[LANGUAGE_EVENTS] = [[event], []]
|
||||
df.to_parquet(data_path)
|
||||
|
||||
info = dataset.meta.info
|
||||
info["features"].update(language_feature_info())
|
||||
write_info(info, root)
|
||||
|
||||
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
|
||||
|
||||
first = reloaded[0]
|
||||
second = reloaded[1]
|
||||
assert first[LANGUAGE_PERSISTENT] == persistent
|
||||
assert first[LANGUAGE_EVENTS] == [event]
|
||||
assert second[LANGUAGE_PERSISTENT] == persistent
|
||||
assert second[LANGUAGE_EVENTS] == []
|
||||
@@ -1,417 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
||||
from lerobot.datasets.language_render import ( # noqa: E402
|
||||
EMITTED_AT_TOLERANCE_S,
|
||||
active_at,
|
||||
emitted_at,
|
||||
nth_next,
|
||||
nth_prev,
|
||||
render_sample,
|
||||
)
|
||||
|
||||
|
||||
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"style": style,
|
||||
"timestamp": timestamp,
|
||||
"camera": camera,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
|
||||
def event_row(role, content, style, tool_calls=None, camera=None):
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"style": style,
|
||||
"camera": camera,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
|
||||
PERSISTENT = [
|
||||
persistent_row("assistant", "plan 0", "plan", 0.0),
|
||||
persistent_row("assistant", "memory 0", "memory", 0.0),
|
||||
persistent_row("assistant", "subtask 0", "subtask", 0.0),
|
||||
persistent_row("assistant", "memory 1", "memory", 1.0),
|
||||
persistent_row("assistant", "subtask 1", "subtask", 1.0),
|
||||
]
|
||||
EVENTS_AT_1 = [
|
||||
event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
|
||||
event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
|
||||
]
|
||||
EVENTS_AT_2 = [
|
||||
event_row("user", "skip wiping", "interjection"),
|
||||
event_row(
|
||||
"assistant",
|
||||
None,
|
||||
None,
|
||||
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
|
||||
),
|
||||
]
|
||||
# Same emission tick, two cameras: triggers per-camera disambiguation in
|
||||
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
|
||||
# (vqa, user) + (vqa, assistant) pair per camera.
|
||||
EVENTS_AT_3_TWO_CAMERAS = [
|
||||
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
|
||||
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
|
||||
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
|
||||
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
|
||||
]
|
||||
|
||||
|
||||
def test_resolver_temporal_semantics():
|
||||
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
|
||||
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
|
||||
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
|
||||
assert (
|
||||
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
|
||||
== '{"count": 2}'
|
||||
)
|
||||
|
||||
|
||||
def test_persistent_relative_resolvers_reject_event_styles():
|
||||
with pytest.raises(ValueError, match="event-only"):
|
||||
active_at(1.0, persistent=PERSISTENT, style="vqa")
|
||||
with pytest.raises(ValueError, match="event-only"):
|
||||
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
|
||||
|
||||
|
||||
def test_nth_prev_and_next():
|
||||
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
|
||||
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
|
||||
|
||||
|
||||
def test_substitution_if_present_multimodal_and_tool_calls():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.top"},
|
||||
{"type": "text", "text": "${task}: ${interjection}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="interjection",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${plan}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
tool_calls_from="speech",
|
||||
),
|
||||
],
|
||||
bindings={"plan": "active_at(t, style=plan)"},
|
||||
)
|
||||
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_2,
|
||||
t=2.0,
|
||||
sample_idx=0,
|
||||
task="clean kitchen",
|
||||
)
|
||||
|
||||
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
|
||||
assert rendered["messages"][1]["content"] == "plan 0"
|
||||
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
|
||||
assert rendered["message_streams"] == ["high_level", "high_level"]
|
||||
assert rendered["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_exact_event_miss_returns_none_when_target_skips():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert (
|
||||
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
|
||||
)
|
||||
|
||||
|
||||
def test_deterministic_blend_sampling():
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"a": TrainingRecipe(
|
||||
weight=1.0,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
"b": TrainingRecipe(
|
||||
weight=1.0,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
first = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
|
||||
)
|
||||
second = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
|
||||
)
|
||||
assert first == second
|
||||
|
||||
|
||||
def test_emitted_at_filters_vqa_by_camera():
|
||||
top = emitted_at(
|
||||
3.0,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
style="vqa",
|
||||
role="assistant",
|
||||
camera="observation.images.top",
|
||||
)
|
||||
wrist = emitted_at(
|
||||
3.0,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
style="vqa",
|
||||
role="assistant",
|
||||
camera="observation.images.wrist",
|
||||
)
|
||||
assert top["content"] == '{"count": 3}'
|
||||
assert wrist["content"] == '{"count": 1}'
|
||||
|
||||
|
||||
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
|
||||
with pytest.raises(ValueError, match="Ambiguous resolver"):
|
||||
emitted_at(
|
||||
3.0,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
style="vqa",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
|
||||
return TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
|
||||
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("camera", "expected_query", "expected_answer"),
|
||||
[
|
||||
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
|
||||
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
|
||||
],
|
||||
)
|
||||
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
|
||||
rendered = render_sample(
|
||||
recipe=_vqa_subrecipe(camera),
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
|
||||
assert rendered["messages"][0]["content"][0]["feature"] == camera
|
||||
assert rendered["messages"][0]["content"][1]["text"] == expected_query
|
||||
assert rendered["messages"][1]["content"] == expected_answer
|
||||
|
||||
|
||||
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
rephrasings = [
|
||||
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
|
||||
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
|
||||
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
|
||||
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
|
||||
]
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
|
||||
# No explicit task override → resolver consults persistent rows.
|
||||
seen: set[str] = set()
|
||||
for sample_idx in range(64):
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=sample_idx,
|
||||
dataset_ctx={"task": "canonical kitchen task"},
|
||||
)
|
||||
seen.add(rendered["messages"][0]["content"])
|
||||
# Every rephrasing should be reachable across enough samples.
|
||||
assert seen == {r["content"] for r in rephrasings}
|
||||
# Same sample_idx → same pick (determinism).
|
||||
a = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
b = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||
|
||||
|
||||
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT, # no task_aug rows
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "clean the kitchen"},
|
||||
)
|
||||
assert rendered["messages"][0]["content"] == "clean the kitchen"
|
||||
|
||||
|
||||
def test_resolve_task_explicit_override_beats_rephrasings():
|
||||
rephrasings = [
|
||||
persistent_row("user", "rephrased one", "task_aug", 0.0),
|
||||
persistent_row("user", "rephrased two", "task_aug", 0.0),
|
||||
]
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="explicit override wins",
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||
|
||||
|
||||
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
|
||||
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
|
||||
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
|
||||
line up with the parquet-stored timestamp.
|
||||
"""
|
||||
rows = [persistent_row("assistant", "memo", "memory", 1.0)]
|
||||
# Half a tolerance window — bit-different float, comfortably inside
|
||||
inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
|
||||
assert inside is not None and inside["content"] == "memo"
|
||||
|
||||
# Just past the window — no match
|
||||
outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
|
||||
assert outside is None
|
||||
|
||||
|
||||
def test_render_sample_rejects_non_dict_language_rows():
|
||||
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
||||
|
||||
A pipeline that hands the renderer a non-dict (e.g. a stray string)
|
||||
is a real upstream bug — silent skipping would let it propagate.
|
||||
"""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
with pytest.raises(TypeError, match="must be dictionaries"):
|
||||
render_sample(
|
||||
recipe=recipe,
|
||||
persistent=["not a dict"],
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="x",
|
||||
)
|
||||
|
||||
|
||||
def test_low_level_branch_renders_active_subtask():
|
||||
low_level = TrainingRecipe(
|
||||
blend={
|
||||
"low": TrainingRecipe(
|
||||
weight=1.0,
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
|
||||
stream="high_level",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="low_level",
|
||||
target=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
rendered = render_sample(
|
||||
recipe=low_level,
|
||||
persistent=PERSISTENT,
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="clean kitchen",
|
||||
)
|
||||
|
||||
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
|
||||
assert rendered["message_streams"][-1] == "low_level"
|
||||
assert rendered["target_message_indices"] == [1]
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for subtask functionality in LeRobotDataset.
|
||||
|
||||
These tests verify that:
|
||||
- Subtask information is correctly loaded from datasets that have subtask data
|
||||
- The __getitem__ method correctly adds subtask strings to returned items
|
||||
- Subtask handling gracefully handles missing data
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class TestSubtaskDataset:
|
||||
"""Tests for subtask handling in LeRobotDataset."""
|
||||
|
||||
@pytest.fixture
|
||||
def subtask_dataset(self):
|
||||
"""Load the test subtask dataset from the hub."""
|
||||
# Use lerobot/pusht-subtask dataset with episode 1
|
||||
return LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
|
||||
def test_subtask_dataset_loads(self, subtask_dataset):
|
||||
"""Test that the subtask dataset loads successfully."""
|
||||
assert subtask_dataset is not None
|
||||
assert len(subtask_dataset) > 0
|
||||
|
||||
def test_subtask_metadata_loaded(self, subtask_dataset):
|
||||
"""Test that subtask metadata is loaded when present in dataset."""
|
||||
# The dataset should have subtasks metadata loaded
|
||||
assert subtask_dataset.meta.subtasks is not None
|
||||
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
|
||||
|
||||
def test_subtask_index_in_features(self, subtask_dataset):
|
||||
"""Test that subtask_index is a feature when dataset has subtasks."""
|
||||
assert "subtask_index" in subtask_dataset.features
|
||||
|
||||
def test_getitem_returns_subtask_string(self, subtask_dataset):
|
||||
"""Test that __getitem__ correctly adds subtask string to returned item."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
# Subtask should be present in the returned item
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["subtask"], str)
|
||||
assert len(item["subtask"]) > 0 # Should not be empty
|
||||
|
||||
def test_getitem_has_subtask_index(self, subtask_dataset):
|
||||
"""Test that __getitem__ includes subtask_index."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
assert "subtask_index" in item
|
||||
assert isinstance(item["subtask_index"], torch.Tensor)
|
||||
|
||||
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
|
||||
"""Test that subtask_index correctly maps to a subtask in metadata."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
assert item["subtask"] == subtask_from_metadata
|
||||
|
||||
def test_all_items_have_subtask(self, subtask_dataset):
|
||||
"""Test that all items in the dataset have subtask information."""
|
||||
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
|
||||
item = subtask_dataset[i]
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["subtask"], str)
|
||||
|
||||
def test_task_and_subtask_coexist(self, subtask_dataset):
|
||||
"""Test that both task and subtask are present in returned items."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
# Both task and subtask should be present
|
||||
assert "task" in item
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["task"], str)
|
||||
assert isinstance(item["subtask"], str)
|
||||
|
||||
|
||||
class TestSubtaskDatasetMissing:
|
||||
"""Tests for graceful handling when subtask data is missing."""
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Create a dataset without subtask information."""
|
||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
|
||||
|
||||
# Add some frames and save
|
||||
for _ in range(5):
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Reload the dataset
|
||||
return LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
def test_no_subtask_in_features(self, dataset_without_subtasks):
|
||||
"""Test that subtask_index is not in features when not provided."""
|
||||
assert "subtask_index" not in dataset_without_subtasks.features
|
||||
|
||||
def test_getitem_without_subtask(self, dataset_without_subtasks):
|
||||
"""Test that __getitem__ works when subtask is not present."""
|
||||
item = dataset_without_subtasks[0]
|
||||
|
||||
# Item should still be retrievable
|
||||
assert item is not None
|
||||
assert "state" in item
|
||||
assert "task" in item
|
||||
|
||||
# Subtask should NOT be present
|
||||
assert "subtask" not in item
|
||||
|
||||
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
|
||||
"""Test that subtasks metadata is None when not present."""
|
||||
assert dataset_without_subtasks.meta.subtasks is None
|
||||
|
||||
|
||||
class TestSubtaskEdgeCases:
|
||||
"""Edge case tests for subtask handling."""
|
||||
|
||||
def test_subtask_with_multiple_episodes(self):
|
||||
"""Test subtask handling with multiple episodes if available."""
|
||||
try:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
except Exception:
|
||||
pytest.skip("Could not load test-subtask dataset")
|
||||
|
||||
# Check first and last items have valid subtasks
|
||||
first_item = dataset[0]
|
||||
last_item = dataset[len(dataset) - 1]
|
||||
|
||||
assert "subtask" in first_item
|
||||
assert "subtask" in last_item
|
||||
assert isinstance(first_item["subtask"], str)
|
||||
assert isinstance(last_item["subtask"], str)
|
||||
|
||||
def test_subtask_index_consistency(self):
|
||||
"""Test that same subtask_index returns same subtask string."""
|
||||
try:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
except Exception:
|
||||
pytest.skip("Could not load test-subtask dataset")
|
||||
|
||||
if len(dataset) < 2:
|
||||
pytest.skip("Dataset too small for this test")
|
||||
|
||||
# Collect subtask_index to subtask mappings
|
||||
subtask_map = {}
|
||||
for i in range(min(len(dataset), 10)):
|
||||
item = dataset[i]
|
||||
idx = item["subtask_index"].item()
|
||||
subtask = item["subtask"]
|
||||
|
||||
if idx in subtask_map:
|
||||
# Same index should always return same subtask
|
||||
assert subtask_map[idx] == subtask, (
|
||||
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
|
||||
)
|
||||
else:
|
||||
subtask_map[idx] = subtask
|
||||
@@ -17,19 +17,19 @@
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import (
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
GaussianActorConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
def test_gaussian_actor_config_default_initialization():
|
||||
config = GaussianActorConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
@@ -55,9 +55,6 @@ def test_sac_config_default_initialization():
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.vision_encoder_name is None
|
||||
@@ -66,6 +63,8 @@ def test_sac_config_default_initialization():
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
@@ -73,20 +72,6 @@ def test_sac_config_default_initialization():
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
@@ -105,11 +90,6 @@ def test_sac_config_default_initialization():
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
@@ -135,7 +115,6 @@ def test_sac_config_default_initialization():
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
@@ -175,22 +154,22 @@ def test_concurrency_config():
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
def test_gaussian_actor_config_custom_initialization():
|
||||
config = GaussianActorConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
latent_dim=128,
|
||||
state_encoder_hidden_dim=128,
|
||||
num_discrete_actions=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
assert config.latent_dim == 128
|
||||
assert config.state_encoder_hidden_dim == 128
|
||||
assert config.num_discrete_actions == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
@@ -198,7 +177,7 @@ def test_validate_features():
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
@@ -209,7 +188,7 @@ def test_validate_features_missing_observation():
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
@@ -0,0 +1,528 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
from torch import Tensor, nn # noqa: E402
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig # noqa: E402
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy # noqa: E402
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE # noqa: E402
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed # noqa: E402
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
seed = 42
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_gaussian_actor_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
GaussianActorPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
||||
return torch.randn(batch_size, action_dim)
|
||||
|
||||
|
||||
def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
ACTION: create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
ACTION: create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> GaussianActorConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
ACTION: {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
},
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> GaussianActorConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats[OBS_IMAGE] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
|
||||
"""Helper to create policy + algorithm pair for tests that need critics."""
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
return algorithm, policy
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
# squeeze(0) removes batch dim when batch_size==1
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
def test_gaussian_actor_policy_select_action_with_discrete():
|
||||
"""select_action should return continuous + discrete actions."""
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.num_discrete_actions = 3
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
|
||||
# Squeeze to unbatched (single observation)
|
||||
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
with torch.no_grad():
|
||||
output = policy.forward(batch)
|
||||
assert "action" in output
|
||||
assert "log_prob" in output
|
||||
assert "action_mean" in output
|
||||
assert output["action"].shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_gaussian_actor_training_through_sac(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
temp_loss = algorithm._compute_loss_temperature(forward_batch)
|
||||
assert temp_loss.item() is not None
|
||||
assert temp_loss.shape == ()
|
||||
algorithm.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
algorithm.optimizers["temperature"].step()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_gaussian_actor_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_gaussian_actor_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_gaussian_actor_training_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_gaussian_actor_training_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
config.num_discrete_actions = 5
|
||||
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
|
||||
assert discrete_critic_loss.shape == ()
|
||||
algorithm.optimizers["discrete_critic"].zero_grad()
|
||||
discrete_critic_loss.backward()
|
||||
algorithm.optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
# Policy.select_action now handles both continuous + discrete
|
||||
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
|
||||
assert selected_action.shape[-1] == continuous_action_dim + 1
|
||||
|
||||
|
||||
def test_sac_algorithm_target_entropy():
|
||||
"""Target entropy is an SAC hyperparameter and lives on the algorithm."""
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
algorithm, _ = _make_algorithm(config)
|
||||
assert algorithm.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
config.num_discrete_actions = 5
|
||||
algorithm, _ = _make_algorithm(config)
|
||||
assert algorithm.target_entropy == -3.5
|
||||
|
||||
|
||||
def test_sac_algorithm_temperature():
|
||||
import math
|
||||
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
assert algorithm.temperature == pytest.approx(1.0)
|
||||
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
assert algorithm.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_algorithm_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algo_config.critic_target_update_weight = 1.0
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
for p in algorithm.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
algorithm._update_target_networks()
|
||||
for p in algorithm.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algo_config.num_critics = num_critics
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
assert len(algorithm.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_gaussian_actor_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded from pretrained."""
|
||||
root = tmp_path / "test_gaussian_actor_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||
root = tmp_path / "test_gaussian_actor_save_and_load_discrete"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_discrete_actions = 3
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
|
||||
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert loaded_policy.discrete_critic is not None
|
||||
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
@@ -1,546 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRANSFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
seed = 42
|
||||
set_seed(seed)
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_sac_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
SACPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
return {
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor:
|
||||
return torch.randn(batch_size, action_dim)
|
||||
|
||||
|
||||
def create_default_train_batch(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
ACTION: create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_state(batch_size, state_dim),
|
||||
"next_state": create_dummy_state(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_train_batch_with_visual_input(
|
||||
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
|
||||
) -> dict[str, Tensor]:
|
||||
return {
|
||||
ACTION: create_dummy_action(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
|
||||
"done": torch.randn(batch_size),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
}
|
||||
|
||||
|
||||
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {
|
||||
"min": [0.0] * state_dim,
|
||||
"max": [1.0] * state_dim,
|
||||
},
|
||||
ACTION: {
|
||||
"min": [0.0] * continuous_action_dim,
|
||||
"max": [1.0] * continuous_action_dim,
|
||||
},
|
||||
},
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
has_discrete_action=has_discrete_action,
|
||||
)
|
||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats[OBS_IMAGE] = {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
assert discrete_critic_loss.shape == ()
|
||||
discrete_critic_loss.backward()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -3.0
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
"""Test that temperature property is always in sync with log_alpha."""
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
# Temperature property automatically reflects log_alpha changes
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 10
|
||||
batch_size = 2
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
+24
-24
@@ -21,8 +21,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
@@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SAC configuration for testing."""
|
||||
config = SACConfig()
|
||||
config = GaussianActorConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
}
|
||||
@@ -66,7 +66,7 @@ def test_make_sac_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -88,12 +88,12 @@ def test_make_sac_processor_basic():
|
||||
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
|
||||
|
||||
|
||||
def test_sac_processor_normalization_modes():
|
||||
def test_gaussian_actor_processor_normalization_modes():
|
||||
"""Test that SAC processor correctly handles different normalization modes."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_cuda():
|
||||
def test_gaussian_actor_processor_cuda():
|
||||
"""Test SAC processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -153,13 +153,13 @@ def test_sac_processor_cuda():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_accelerate_scenario():
|
||||
def test_gaussian_actor_processor_accelerate_scenario():
|
||||
"""Test SAC processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario():
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_sac_processor_multi_gpu():
|
||||
def test_gaussian_actor_processor_multi_gpu():
|
||||
"""Test SAC processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu():
|
||||
assert processed[TransitionKey.ACTION.value].device == device
|
||||
|
||||
|
||||
def test_sac_processor_without_stats():
|
||||
def test_gaussian_actor_processor_without_stats():
|
||||
"""Test SAC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -226,12 +226,12 @@ def test_sac_processor_without_stats():
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_sac_processor_save_and_load():
|
||||
def test_gaussian_actor_processor_save_and_load():
|
||||
"""Test saving and loading SAC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -257,14 +257,14 @@ def test_sac_processor_save_and_load():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_mixed_precision():
|
||||
def test_gaussian_actor_processor_mixed_precision():
|
||||
"""Test SAC processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision():
|
||||
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
|
||||
|
||||
|
||||
def test_sac_processor_batch_data():
|
||||
def test_gaussian_actor_processor_batch_data():
|
||||
"""Test SAC processor with batched data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -329,12 +329,12 @@ def test_sac_processor_batch_data():
|
||||
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
|
||||
|
||||
|
||||
def test_sac_processor_edge_cases():
|
||||
def test_gaussian_actor_processor_edge_cases():
|
||||
"""Test SAC processor with edge cases."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -358,13 +358,13 @@ def test_sac_processor_edge_cases():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_bfloat16_device_float32_normalizer():
|
||||
def test_gaussian_actor_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_sac_pre_post_processors(
|
||||
preprocessor, _ = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -1804,13 +1804,15 @@ def test_stats_override_preservation_in_load_state_dict():
|
||||
override_normalizer.stats[key][stat_name], original_stats[key][stat_name]
|
||||
), f"Stats for {key}.{stat_name} should not match original stats"
|
||||
|
||||
# Verify that _tensor_stats are also correctly set to match the override stats
|
||||
# Verify that _tensor_stats values match the override stats
|
||||
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats
|
||||
expected_tensor_stats = to_tensor(override_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor):
|
||||
torch.testing.assert_close(
|
||||
override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name]
|
||||
override_normalizer._tensor_stats[key][stat_name].squeeze(),
|
||||
expected_tensor_stats[key][stat_name].squeeze(),
|
||||
)
|
||||
|
||||
|
||||
@@ -1849,12 +1851,16 @@ def test_stats_without_override_loads_normally():
|
||||
# Stats should now match the original stats (normal behavior)
|
||||
# Check that all keys and values match
|
||||
assert set(new_normalizer.stats.keys()) == set(original_stats.keys())
|
||||
# Note: visual stats are reshaped from (C,) to (C,1,1) by _reshape_visual_stats,
|
||||
# so we squeeze before comparing values.
|
||||
for key in original_stats:
|
||||
assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys())
|
||||
for stat_name in original_stats[key]:
|
||||
np.testing.assert_allclose(
|
||||
new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6
|
||||
)
|
||||
actual = new_normalizer.stats[key][stat_name]
|
||||
expected = original_stats[key][stat_name]
|
||||
if hasattr(actual, "squeeze"):
|
||||
actual = actual.squeeze()
|
||||
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)
|
||||
|
||||
|
||||
def test_stats_explicit_provided_flag_detection():
|
||||
@@ -2075,8 +2081,9 @@ def test_stats_reconstruction_after_load_state_dict():
|
||||
assert ACTION in new_normalizer.stats
|
||||
|
||||
# Check that values are correct (converted back from tensors)
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2])
|
||||
# Note: visual stats are reshaped to (C,1,1), so we squeeze before comparing
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"].squeeze(), [0.5, 0.5, 0.5])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"].squeeze(), [0.2, 0.2, 0.2])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0])
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
||||
from lerobot.processor.converters import create_transition # noqa: E402
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep # noqa: E402
|
||||
from lerobot.types import TransitionKey # noqa: E402
|
||||
|
||||
|
||||
def test_render_messages_step_noops_without_language_columns():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
|
||||
]
|
||||
)
|
||||
transition = create_transition(complementary_data={"task": "do it"})
|
||||
|
||||
assert RenderMessagesStep(recipe)(transition) == transition
|
||||
|
||||
|
||||
def test_render_messages_step_renders_and_drops_raw_language():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "do it",
|
||||
"timestamp": torch.tensor(0.0),
|
||||
"index": torch.tensor(7),
|
||||
"language_persistent": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "reach carefully",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
"language_events": [],
|
||||
}
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert "language_persistent" not in data
|
||||
assert "language_events" not in data
|
||||
assert data["messages"][-1]["content"] == "reach carefully"
|
||||
assert data["message_streams"] == ["high_level", "low_level"]
|
||||
assert data["target_message_indices"] == [1]
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
@@ -36,9 +35,6 @@ def test_classifier_output():
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
@@ -80,9 +76,6 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
@@ -122,9 +115,6 @@ def test_multiclass_classifier():
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_default_device():
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
@@ -141,9 +131,6 @@ def test_default_device():
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
@pytest.mark.skip(
|
||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
||||
)
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
|
||||
@@ -22,12 +22,14 @@ import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
@@ -79,7 +81,7 @@ def cfg():
|
||||
|
||||
port = find_free_port()
|
||||
|
||||
policy_cfg = SACConfig()
|
||||
policy_cfg = GaussianActorConfig()
|
||||
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
|
||||
policy_cfg.actor_learner_config.learner_port = port
|
||||
policy_cfg.concurrency.actor = "threads"
|
||||
@@ -299,3 +301,164 @@ def test_end_to_end_parameters_flow(cfg, data_size):
|
||||
assert received_params.keys() == input_params.keys()
|
||||
for key in input_params:
|
||||
assert torch.allclose(received_params[key], input_params[key])
|
||||
|
||||
|
||||
def test_learner_algorithm_wiring():
|
||||
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
||||
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
|
||||
get_weights() output is serializable."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.transport.utils import state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
sac_cfg = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
assert "temperature" in optimizers
|
||||
|
||||
batch_size = 4
|
||||
|
||||
def batch_iterator():
|
||||
while True:
|
||||
yield {
|
||||
ACTION: torch.randn(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": {OBS_STATE: torch.randn(batch_size, state_dim)},
|
||||
"next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
|
||||
"done": torch.zeros(batch_size),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
stats = algorithm.update(batch_iterator())
|
||||
assert "loss_critic" in stats.losses
|
||||
|
||||
# get_weights -> state_to_bytes round-trip
|
||||
weights = algorithm.get_weights()
|
||||
assert len(weights) > 0
|
||||
serialized = state_to_bytes(weights)
|
||||
assert isinstance(serialized, bytes)
|
||||
assert len(serialized) > 0
|
||||
|
||||
# RLTrainer with DataMixer
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=50,
|
||||
device="cpu",
|
||||
state_keys=[OBS_STATE],
|
||||
storage_device="cpu",
|
||||
use_drq=False,
|
||||
)
|
||||
for _ in range(50):
|
||||
replay_buffer.add(
|
||||
state={OBS_STATE: torch.randn(state_dim)},
|
||||
action=torch.randn(action_dim),
|
||||
reward=1.0,
|
||||
next_state={OBS_STATE: torch.randn(state_dim)},
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
trainer_stats = trainer.training_step()
|
||||
assert "loss_critic" in trainer_stats.losses
|
||||
|
||||
|
||||
def test_initial_and_periodic_weight_push_consistency():
|
||||
"""Both initial and periodic weight pushes should use algorithm.get_weights()
|
||||
and produce identical structures."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithmConfig
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
# Simulate initial push (same code path the learner now uses)
|
||||
initial_weights = algorithm.get_weights()
|
||||
initial_bytes = state_to_bytes(initial_weights)
|
||||
|
||||
# Simulate periodic push
|
||||
periodic_weights = algorithm.get_weights()
|
||||
periodic_bytes = state_to_bytes(periodic_weights)
|
||||
|
||||
initial_decoded = bytes_to_state_dict(initial_bytes)
|
||||
periodic_decoded = bytes_to_state_dict(periodic_bytes)
|
||||
|
||||
assert initial_decoded.keys() == periodic_decoded.keys()
|
||||
|
||||
|
||||
def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
# Actor side: no optimizers
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.eval()
|
||||
algorithm = make_algorithm(cfg=SACAlgorithmConfig.from_policy_config(sac_cfg), policy=policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
# select_action should work
|
||||
obs = {OBS_STATE: torch.randn(state_dim)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
# Simulate receiving weights from learner
|
||||
fake_weights = algorithm.get_weights()
|
||||
algorithm.load_weights(fake_weights, device="cpu")
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer # noqa: E402
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer # noqa: E402
|
||||
from lerobot.utils.constants import OBS_STATE # noqa: E402
|
||||
|
||||
|
||||
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
|
||||
buf = ReplayBuffer(
|
||||
capacity=capacity,
|
||||
device="cpu",
|
||||
state_keys=[OBS_STATE],
|
||||
storage_device="cpu",
|
||||
use_drq=False,
|
||||
)
|
||||
for i in range(capacity):
|
||||
buf.add(
|
||||
state={OBS_STATE: torch.randn(state_dim)},
|
||||
action=torch.randn(2),
|
||||
reward=1.0,
|
||||
next_state={OBS_STATE: torch.randn(state_dim)},
|
||||
done=bool(i % 10 == 9),
|
||||
truncated=False,
|
||||
)
|
||||
return buf
|
||||
|
||||
|
||||
def test_online_only_mixer_sample():
|
||||
"""OnlineOfflineMixer with no offline buffer returns online-only batches."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5)
|
||||
batch = mixer.sample(batch_size=8)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 8
|
||||
assert batch["action"].shape[0] == 8
|
||||
assert batch["reward"].shape[0] == 8
|
||||
|
||||
|
||||
def test_online_only_mixer_ratio_one():
|
||||
"""OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0)
|
||||
batch = mixer.sample(batch_size=10)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 10
|
||||
|
||||
|
||||
def test_online_offline_mixer_sample():
|
||||
"""OnlineOfflineMixer with two buffers returns concatenated batches."""
|
||||
online = _make_buffer(capacity=50)
|
||||
offline = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(
|
||||
online_buffer=online,
|
||||
offline_buffer=offline,
|
||||
online_ratio=0.5,
|
||||
)
|
||||
batch = mixer.sample(batch_size=10)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 10
|
||||
assert batch["action"].shape[0] == 10
|
||||
# 5 from online, 5 from offline (approx)
|
||||
assert batch["reward"].shape[0] == 10
|
||||
|
||||
|
||||
def test_online_offline_mixer_iterator():
|
||||
"""get_iterator yields batches of the requested size."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None)
|
||||
it = mixer.get_iterator(batch_size=4, async_prefetch=False)
|
||||
batch1 = next(it)
|
||||
batch2 = next(it)
|
||||
assert batch1["state"][OBS_STATE].shape[0] == 4
|
||||
assert batch2["state"][OBS_STATE].shape[0] == 4
|
||||
@@ -20,7 +20,7 @@ from queue import Queue
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user