mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5753f8c18b | |||
| 97bd373d15 | |||
| 10a73e3c95 | |||
| 27c9288b24 | |||
| 378897800a | |||
| fcb371eddd | |||
| 895eaf0d7c | |||
| edda8552ec | |||
| c8225d749a | |||
| 68f869b7a0 | |||
| 4119ad4d10 | |||
| 750358895b | |||
| bc4d0db8f4 | |||
| 45e273b806 | |||
| 8b5f56b63c | |||
| 9f1ee224cb | |||
| 885f55ef04 | |||
| bba996ef8d | |||
| 162b07512a | |||
| 0509ea05df | |||
| de1a9e5ad9 | |||
| 6803439f22 | |||
| 90d1e70da2 | |||
| a35ac22afd | |||
| fd7fed08e2 | |||
| 2e9cd87bbd | |||
| 0c3cc4c9d6 | |||
| 6caeac9d07 | |||
| 1d6810b814 | |||
| de9af57475 | |||
| d1b1c5c8cf | |||
| 741c2d0a39 | |||
| 19fe315971 | |||
| 364750ada2 | |||
| 342d223706 | |||
| e3b203e5a7 | |||
| 906b585826 | |||
| b568c41355 | |||
| b8ad81bf39 | |||
| 24017e960c | |||
| e86f5af5bf | |||
| 5c98e80430 | |||
| f65f3f7a4a | |||
| 8194897994 | |||
| 9f437d86b6 | |||
| b74a551d38 | |||
| c0a2e9814d | |||
| bac4f61eae | |||
| f4b834844e | |||
| dfdc48a7f1 | |||
| 6a8878a639 | |||
| d38eb89f71 | |||
| 7ab4936b1b | |||
| ca8c60a0ed | |||
| 3c15fd8537 | |||
| 5ebbdf3d05 | |||
| 6e035fb169 | |||
| 01dcb4c292 | |||
| bd9619dfc3 | |||
| 0a4a7c40ad | |||
| ca9028ad64 | |||
| 9db9c35cb4 | |||
| fe96b28c74 | |||
| 2438df1307 | |||
| f218d5ab30 | |||
| 04125492e4 | |||
| e963e5a0c4 |
@@ -105,7 +105,7 @@ lerobot-train \
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
# Video benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the optimal trade-off between:
|
||||
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies,
|
||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
||||
|
||||
How to encode videos?
|
||||
|
||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
||||
|
||||
## Variables
|
||||
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
### Encoding parameters
|
||||
|
||||
| parameter | values |
|
||||
| ----------- | ------------------------------------------------------------ |
|
||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
||||
|
||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
||||
|
||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
||||
|
||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
||||
|
||||
### Decoding parameters
|
||||
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
|
||||
- `pyav`
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
||||
|
||||
- `1_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
|
||||
## Metrics
|
||||
|
||||
**Data compression ratio (lower is better)**
|
||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Loading time ratio (lower is better)**
|
||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average Mean Square Error (lower is better)**
|
||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
||||
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
||||
|
||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
||||
|
||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
||||
|
||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
||||
These are then all concatenated to a single table ready for analysis.
|
||||
|
||||
## Caveats
|
||||
|
||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
|
||||
- `torchaudio`
|
||||
- `ffmpegio`
|
||||
- `decord`
|
||||
- `nvc`
|
||||
|
||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
||||
|
||||
## Install
|
||||
|
||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
||||
|
||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
||||
|
||||
## Adding a video decoder
|
||||
|
||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
||||
|
||||
```diff
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
+ elif backend == ["your_decoder"]:
|
||||
+ return your_decoder_function(
|
||||
+ video_path, timestamps, tolerance_s, backend
|
||||
+ )
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
For a quick run, you can try these parameters:
|
||||
|
||||
```bash
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
--crf 10 40 None \
|
||||
--timestamps-modes 1_frame 2_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 5 \
|
||||
--num-workers 5 \
|
||||
--save-frames 0
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### Reproduce
|
||||
|
||||
We ran the benchmark with the following parameters:
|
||||
|
||||
```bash
|
||||
# h264 and h265 encodings
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
|
||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
```
|
||||
|
||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
||||
|
||||
### Parameters selected for LeRobotDataset
|
||||
|
||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
||||
|
||||
- vcodec: `libsvtav1`
|
||||
- pix-fmt: `yuv420p`
|
||||
- g: `2`
|
||||
- crf: `30`
|
||||
|
||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
||||
|
||||
### Summary
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
@@ -1,488 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will benchmark different video encoding and decoding parameters.
|
||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import itertools
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.utils import TimerManager
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
("vcodec", "libx264"),
|
||||
("pix_fmt", "yuv444p"),
|
||||
("g", 2),
|
||||
("crf", None),
|
||||
# TODO(aliberts): Add fastdecode
|
||||
# ("fastdecode", 0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
||||
def parse_int_or_none(value) -> int | None:
|
||||
if value.lower() == "none":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
||||
|
||||
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
total_size = 0
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return torch.stack(frames)
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
case "1_frame":
|
||||
frame_indexes = [idx]
|
||||
case "2_frames":
|
||||
frame_indexes = [idx - 1, idx]
|
||||
case "2_frames_4_space":
|
||||
frame_indexes = [idx - 5, idx]
|
||||
case "6_frames":
|
||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
||||
case _:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
timestamps_mode: str,
|
||||
backend: str,
|
||||
ep_num_images: int,
|
||||
fps: int,
|
||||
num_samples: int = 50,
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int, lock: Lock):
|
||||
time_benchmark = TimerManager(log=False)
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
"psnr_values": [],
|
||||
"ssim_values": [],
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark, lock:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
||||
|
||||
return result
|
||||
|
||||
load_times_video_ms = []
|
||||
load_times_images_ms = []
|
||||
mse_values = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
# Use a single shared lock for all worker threads
|
||||
shared_lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
psnr_values.extend(result["psnr_values"])
|
||||
ssim_values.extend(result["ssim_values"])
|
||||
mse_values.extend(result["mse_values"])
|
||||
|
||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
||||
|
||||
return {
|
||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
||||
"avg_mse": float(np.mean(mse_values)),
|
||||
"avg_psnr": float(np.mean(psnr_values)),
|
||||
"avg_ssim": float(np.mean(ssim_values)),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_encoding_decoding(
|
||||
dataset: LeRobotDataset,
|
||||
video_path: Path,
|
||||
imgs_dir: Path,
|
||||
encoding_cfg: dict,
|
||||
decoding_cfg: dict,
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
overwrite: bool = False,
|
||||
seed: int = 1337,
|
||||
) -> list[dict]:
|
||||
fps = dataset.fps
|
||||
|
||||
if overwrite or not video_path.is_file():
|
||||
tqdm.write(f"encoding {video_path}")
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
||||
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
timestamps_mode,
|
||||
backend,
|
||||
ep_num_images,
|
||||
fps,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
benchmark_row.update(
|
||||
**{
|
||||
"repo_id": dataset.repo_id,
|
||||
"resolution": f"{width} x {height}",
|
||||
"num_pixels": num_pixels,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"images_size_bytes": images_size_bytes,
|
||||
"video_images_size_ratio": video_images_size_ratio,
|
||||
"timestamps_mode": timestamps_mode,
|
||||
"backend": backend,
|
||||
},
|
||||
**encoding_cfg,
|
||||
)
|
||||
benchmark_table.append(benchmark_row)
|
||||
|
||||
return benchmark_table
|
||||
|
||||
|
||||
def main(
|
||||
output_dir: Path,
|
||||
repo_ids: list[str],
|
||||
vcodec: list[str],
|
||||
pix_fmt: list[str],
|
||||
g: list[int],
|
||||
crf: list[int],
|
||||
# fastdecode: list[int],
|
||||
timestamps_modes: list[str],
|
||||
backends: list[str],
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
):
|
||||
check_datasets_formats(repo_ids)
|
||||
encoding_benchmarks = {
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
# "fastdecode": fastdecode,
|
||||
}
|
||||
decoding_benchmarks = {
|
||||
"timestamps_modes": timestamps_modes,
|
||||
"backends": backends,
|
||||
}
|
||||
headers = ["repo_id", "resolution", "num_pixels"]
|
||||
headers += list(BASE_ENCODING.keys())
|
||||
headers += [
|
||||
"timestamps_mode",
|
||||
"backend",
|
||||
"video_size_bytes",
|
||||
"images_size_bytes",
|
||||
"video_images_size_ratio",
|
||||
"avg_load_time_video_ms",
|
||||
"avg_load_time_images_ms",
|
||||
"video_images_load_time_ratio",
|
||||
"avg_mse",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
]
|
||||
file_paths = []
|
||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
||||
benchmark_table = []
|
||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for duet in [
|
||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
||||
]:
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for key, value in duet.items():
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
now = dt.datetime.now()
|
||||
csv_path = (
|
||||
output_dir
|
||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
||||
)
|
||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
||||
file_paths.append(csv_path)
|
||||
del benchmark_df
|
||||
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/video_benchmark"),
|
||||
help="Directory where the video benchmark outputs are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"lerobot/aloha_mobile_shrimp_image",
|
||||
"lerobot/paris_street",
|
||||
"lerobot/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["h264", "hevc", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["yuv444p", "yuv420p"],
|
||||
help="Pixel formats (chroma subsampling) to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
help="Group of pictures sizes to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
||||
help="Constant rate factors to be tested.",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fastdecode",
|
||||
# type=int,
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265/hevc, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--timestamps-modes",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
],
|
||||
help="Timestamps scenarios to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["torchcodec", "pyav"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of samples for each encoding x decoding config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of processes for parallelized sample processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
@@ -3,10 +3,14 @@
|
||||
title: LeRobot
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: cheat-sheet
|
||||
title: Cheat sheet
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
@@ -37,8 +41,12 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
@@ -53,12 +61,14 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: vla_jepa
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: evo1
|
||||
title: EVO1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: NVIDIA GR00T
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: multi_task_dit
|
||||
@@ -69,6 +79,10 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: robometer
|
||||
title: ROBOMETER
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
@@ -141,6 +155,8 @@
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
- local: rebot_b601
|
||||
title: reBot B601-DM
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
|
||||
+6
-10
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/act_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_robot \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
--task="Your task description" \ # can be skipped for ACT
|
||||
--duration=60
|
||||
```
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
# Cheat sheet
|
||||
|
||||
All of the LeRobot commands in one place. If you forgot how to use a specific command or want to learn about a new one you can do it here.
|
||||
|
||||
> [!WARNING]
|
||||
> For all of the commands listed below remember to change the ports/names/ids to your own values!
|
||||
|
||||
> [!TIP]
|
||||
> Another great way to look at all the commands and get them configured for your specific setup is to use this [Jupyter Notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb).
|
||||
|
||||
### Setup and installation
|
||||
|
||||
For installation please look at [LeRobot Installation](https://huggingface.co/docs/lerobot/main/en/installation).
|
||||
|
||||
### Useful tools
|
||||
|
||||
###### Find port
|
||||
|
||||
Use this to identify which serial ports your robots are connected to. Follow the instructions in your terminal: you will be asked to unplug the USB cable and press Enter. The script will then detect and print the correct serial port for that robot.
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
###### Find cameras
|
||||
|
||||
Quickly find camera indices and verify their output. This command prints camera information to the terminal and saves test frames from each detected camera to `lerobot/outputs/captured_images`
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras
|
||||
```
|
||||
|
||||
### Calibration
|
||||
|
||||
In most cases you will need to perform calibration just once for each robot and teleoperation device. Before performing the calibration make sure that all the joints are roughly in the middle position.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm
|
||||
```
|
||||
|
||||
Make sure that you use the same IDs used during calibration later for the other scripts. That's how LeRobot finds the calibration files.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
Teleoperating with two cameras and displaying the data with Rerun.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
### Recording a dataset
|
||||
|
||||
The dataset is automatically uploaded to the server and saved under repo_id, make sure you are logged in to your HF account with CLI:
|
||||
`hf auth login`
|
||||
|
||||
You can get the token from: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{ top: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30} }" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--dataset.num_episodes=30 \
|
||||
--dataset.single_task="put the red brick in a bowl" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
While collecting the dataset you can control the process with your keyboard:
|
||||
Control the data recording flow using keyboard shortcuts:
|
||||
|
||||
- Press **Right Arrow (`→`)**: Save episode and move to the next.
|
||||
- Press **Left Arrow (`←`)**: Delete current episode and retry.
|
||||
- Press **Escape (`ESC`)**: Stop, encode videos, and upload.
|
||||
|
||||
### Training
|
||||
|
||||
Depending on your hardware training the policy might take a few hours. That's how you train simple `ACT` policy:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--job_name=act_so101_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
- Policy Types: `act`, `diffusion`, `smolvla`, `pi05`
|
||||
- Devices: `cuda` (NVIDIA), `mps` (Apple Silicon), `cpu`
|
||||
|
||||
If you want to fine-tune a specific model you can provide the path to the model. In this case path is enough and type can be skipped.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_dataset_test \
|
||||
--policy.path=username/the_policy_to_finetune \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=${HF_USER}/policy_test \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
|
||||
|
||||
> [!TIP]
|
||||
> If you are using the previous release V0.5.1 instead of `lerobot-rollout` you need to use `lerobot-record`. More information [here](https://huggingface.co/docs/lerobot/v0.5.1/en/il_robots#run-inference-and-evaluate-your-policy).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
@@ -1,277 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
|
||||
|
||||
- [SmolVLA](./smolvla)
|
||||
- [Pi0.5](./pi05)
|
||||
- [GR00T N1.5](./groot)
|
||||
- [GR00T N1.7](./groot)
|
||||
|
||||
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
|
||||
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
|
||||
|
||||
## Model Overview
|
||||
|
||||
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=evo1` configuration through LeRobot
|
||||
- InternVL3 image/text embedding with optional FlashAttention fallback
|
||||
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with standard policy inference APIs
|
||||
|
||||
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EVO1 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1]"
|
||||
```
|
||||
|
||||
For LIBERO evaluation, install the LIBERO extra as well:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1,libero]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint.
|
||||
|
||||
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EVO1 expects a LeRobot dataset with:
|
||||
|
||||
- One to `policy.max_views` visual observations, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
|
||||
|
||||
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EVO1 in a LeRobot configuration, specify:
|
||||
|
||||
```python
|
||||
policy.type=evo1
|
||||
```
|
||||
|
||||
By default, a new EVO1 policy initializes its VLM from:
|
||||
|
||||
```python
|
||||
policy.vlm_model_name=OpenGVLab/InternVL3-1B
|
||||
```
|
||||
|
||||
Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-evo1-checkpoint
|
||||
```
|
||||
|
||||
The converted LIBERO checkpoint used for this PR is available at:
|
||||
|
||||
```python
|
||||
policy.path=javadcc/evo1-libero-lerobot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Stage 1
|
||||
|
||||
Stage 1 freezes the VLM and trains the action head:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=evo1 \
|
||||
--policy.training_stage=stage1 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=5000 \
|
||||
--output_dir=./outputs/evo1_stage1
|
||||
```
|
||||
|
||||
### Stage 2
|
||||
|
||||
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
|
||||
--policy.training_stage=stage2 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=80000 \
|
||||
--output_dir=./outputs/evo1_stage2
|
||||
```
|
||||
|
||||
By default, `policy.training_stage` reapplies the finetuning defaults for that stage. This is important when
|
||||
starting Stage 2 from a Stage 1 checkpoint, because the Stage 1 checkpoint config stores the VLM finetuning
|
||||
flags as disabled. These stage defaults take precedence over saved or manually supplied `policy.finetune_*`
|
||||
flags unless `policy.apply_training_stage_defaults=false`, so set that flag only when manually controlling
|
||||
every finetuning flag.
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
|
||||
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
|
||||
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
|
||||
| `policy.apply_training_stage_defaults` | `true` | Reapplies stage finetuning defaults after loading a checkpoint |
|
||||
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
|
||||
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
|
||||
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
|
||||
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
|
||||
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
|
||||
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.max_state_dim` | `24` | State padding dimension |
|
||||
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## Results
|
||||
|
||||
### LIBERO Object Checkpoint Conversion
|
||||
|
||||
The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot)
|
||||
is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against
|
||||
the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing.
|
||||
|
||||
| Checkpoint | Suite | Episodes | Success Rate |
|
||||
| ---------------------------- | --------------- | ---------------- | ------------ |
|
||||
| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
|
||||
For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical
|
||||
pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step
|
||||
(`max_abs_diff=0.0`).
|
||||
|
||||
The published checkpoint expects the raw LIBERO camera feature names
|
||||
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. To run the converted
|
||||
checkpoint with LeRobot LIBERO evaluation for the same one-episode-per-task setting, keep those camera names
|
||||
instead of the default `image`/`image2` mapping:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=javadcc/evo1-libero-lerobot \
|
||||
--policy.device=cuda \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--env.camera_name_mapping="{agentview_image: agentview_image, robot0_eye_in_hand_image: robot0_eye_in_hand_image}" \
|
||||
--env.observation_height=448 \
|
||||
--env.observation_width=448 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
|
||||
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
|
||||
+85
-36
@@ -1,16 +1,19 @@
|
||||
# GR00T N1.5 Policy
|
||||
# GR00T Policy
|
||||
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
|
||||
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
|
||||
## Model Overview
|
||||
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
||||
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
|
||||
@@ -28,33 +31,46 @@ This approach allows the model to be highly adaptable through post-training for
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
|
||||
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
|
||||
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
```
|
||||
|
||||
3. Install and verify Flash Attention:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
3. Install LeRobot by running:
|
||||
4. Install LeRobot with the GR00T extra:
|
||||
|
||||
```bash
|
||||
pip install lerobot[groot]
|
||||
pip install "lerobot[groot]"
|
||||
```
|
||||
|
||||
For a source checkout, use the same order, then install the local package with:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
To use GR00T N1.7:
|
||||
|
||||
```python
|
||||
policy.type=groot
|
||||
```bash
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7
|
||||
```
|
||||
|
||||
## Training
|
||||
@@ -87,28 +103,63 @@ accelerate launch \
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
### LIBERO Benchmark Results
|
||||
|
||||
> [!NOTE]
|
||||
> Follow our instructions for Libero usage: [Libero](./libero)
|
||||
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
|
||||
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
|
||||
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
### GR00T N1.7 LIBERO Checkpoints
|
||||
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
|
||||
|
||||
| Suite | Checkpoint subdirectory |
|
||||
| -------------- | ----------------------- |
|
||||
| LIBERO Spatial | `libero_spatial` |
|
||||
| LIBERO Object | `libero_object` |
|
||||
| LIBERO Goal | `libero_goal` |
|
||||
| LIBERO 10 | `libero_10` |
|
||||
|
||||
Preliminary LeRobot integration results:
|
||||
|
||||
| Suite | Status | Success rate | n_episodes |
|
||||
| -------------- | ------ | -----------: | ---------: |
|
||||
| LIBERO Spatial | ✓ | ~95% | XX |
|
||||
| LIBERO Object | ✓ | XX% | XX |
|
||||
| LIBERO Goal | ✓ | XX% | XX |
|
||||
| LIBERO 10 | ✓ | XX% | XX |
|
||||
| **Average** | ✓ | **XX%** | **XX** |
|
||||
|
||||
Replace the `XX` placeholders with final eval artifacts before merge.
|
||||
|
||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||
|
||||
```bash
|
||||
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
--include "libero_spatial/*" \
|
||||
--local-dir ./GR00T-N1.7-LIBERO
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7 \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--eval.n_episodes=50
|
||||
```
|
||||
|
||||
Use `eval.n_episodes >= 50` per suite when reporting success rates.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
@@ -119,16 +170,14 @@ lerobot-record \
|
||||
}' \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
--duration=600
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
|
||||
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
|
||||
|
||||
+40
-37
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
|
||||
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
||||
```
|
||||
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
|
||||
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
|
||||
|
||||
**Example Configuration**
|
||||
|
||||
```json
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.20, 0.10],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"end_effector_bounds": {
|
||||
"max": [0.24, 0.2, 0.1],
|
||||
"min": [0.16, -0.08, 0.03]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
|
||||
|
||||
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
|
||||
|
||||
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
|
||||
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
# bounds for the end-effector in x,y,z direction
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
# maximum step size for the end-effector in x,y,z direction
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
class HILSerlProcessorConfig:
|
||||
...
|
||||
# maximum gripper position that the gripper will be open at
|
||||
max_gripper_pos: float | None = 100.0
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
|
||||
|
||||
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
|
||||
|
||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||
|
||||
**Collecting a Dataset for the reward classifier**
|
||||
|
||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||
|
||||
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
||||
|
||||
@@ -658,7 +661,7 @@ Example configuration section for data collection:
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
@@ -671,7 +674,7 @@ Example configuration section for data collection:
|
||||
|
||||
**Reward Classifier Configuration**
|
||||
|
||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
||||
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
|
||||
|
||||
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
|
||||
- **model_type**: `"cnn"` or `"transformer"`
|
||||
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"root": null
|
||||
},
|
||||
"policy": {
|
||||
"reward_model": {
|
||||
"type": "reward_classifier",
|
||||
"model_name": "helper2424/resnet10",
|
||||
"model_type": "cnn",
|
||||
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"dropout_rate": 0.1,
|
||||
"learning_rate": 1e-4,
|
||||
"device": "cuda",
|
||||
"use_amp": true,
|
||||
"input_features": {
|
||||
"observation.images.front": {
|
||||
"type": "VISUAL",
|
||||
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
**Configuration Setup**
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
|
||||
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
|
||||
3. Set `dataset` to your cropped dataset
|
||||
4. Configure environment settings with crop parameters
|
||||
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
|
||||
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
||||
|
||||
Some configuration values have a disproportionate impact on training stability and speed:
|
||||
|
||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
+210
-109
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760431541",
|
||||
id="my_red_robot_arm",
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
)
|
||||
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
||||
--teleop.type=koch_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=my_awesome_leader_arm \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||
--robot.id=my_follower_arm \
|
||||
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||
--teleop.id=my_leader_arm \
|
||||
--display_data=true
|
||||
```
|
||||
</hfoption>
|
||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
import time
|
||||
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||
|
||||
camera_config = {
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
||||
}
|
||||
|
||||
robot_config = KochFollowerConfig(
|
||||
port="/dev/tty.usbmodem585A0076841",
|
||||
id="my_red_robot_arm",
|
||||
cameras=camera_config
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
teleop_config = KochLeaderConfig(
|
||||
port="/dev/tty.usbmodem58760431551",
|
||||
id="my_blue_leader_arm",
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
robot = KochFollower(robot_config)
|
||||
teleop_device = KochLeader(teleop_config)
|
||||
init_rerun(session_name="teleoperation")
|
||||
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop_device = SO101Leader(teleop_config)
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
TARGET_HZ = 30
|
||||
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||
|
||||
while True:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
observation = robot.get_observation()
|
||||
action = teleop_device.get_action()
|
||||
robot.send_action(action)
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
elapsed_time = time.perf_counter() - start_time
|
||||
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -193,7 +207,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
@@ -202,10 +216,11 @@ lerobot-record \
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
@@ -218,71 +233,56 @@ EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
|
||||
# Create robot configuration
|
||||
robot_config = SO100FollowerConfig(
|
||||
id="my_awesome_follower_arm",
|
||||
cameras={
|
||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
||||
},
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
)
|
||||
|
||||
teleop_config = SO100LeaderConfig(
|
||||
id="my_awesome_leader_arm",
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop = SO100Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
def main():
|
||||
# Create robot configuration
|
||||
robot_config = SO101FollowerConfig(
|
||||
port="/dev/tty.usbmodem5AB90687491",
|
||||
id="my_follower_arm",
|
||||
cameras={
|
||||
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||
}
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
teleop_config = SO101LeaderConfig(
|
||||
port="/dev/tty.usbmodem5AB90689011",
|
||||
id="my_leader_arm",
|
||||
)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO101Follower(robot_config)
|
||||
teleop = SO101Leader(teleop_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<dataset_repo_id>",
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Create the required processors
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
@@ -291,26 +291,50 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# finalize dataset
|
||||
log_say("Finalizing dataset...")
|
||||
dataset.finalize()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
||||
##### 2. Checkpointing and Resuming
|
||||
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
||||
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 3. Recording Parameters
|
||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Train using Hugging Face Jobs
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
hf jobs run \
|
||||
--flavor a10g-small \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
-- \
|
||||
python -m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=username/dataset \
|
||||
--policy.type=act \
|
||||
--steps=5000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id=username/your_policy \
|
||||
--log_freq=100
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from huggingface_hub import run_job, get_token
|
||||
|
||||
run_name = "act_so101_hf_jobs"
|
||||
dataset_id = "username/dataset"
|
||||
user_hub_id = "username"
|
||||
|
||||
command_args = [
|
||||
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||
"--dataset.repo_id", dataset_id,
|
||||
"--policy.type", "act",
|
||||
"--steps", "5000",
|
||||
"--batch_size", "16",
|
||||
"--num_workers", "4",
|
||||
"--policy.device", "cuda",
|
||||
"--log_freq", "100",
|
||||
"--save_freq", "1000",
|
||||
"--save_checkpoint", "true",
|
||||
"--wandb.enable", "false",
|
||||
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||
]
|
||||
|
||||
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||
|
||||
job_info = run_job(
|
||||
image="huggingface/lerobot-gpu:latest",
|
||||
command=command_args,
|
||||
flavor="a10g-small",
|
||||
timeout="4h",
|
||||
secrets={"HF_TOKEN": get_token()}
|
||||
)
|
||||
|
||||
print("\n🚀 Job successfully launched!")
|
||||
print(f"🔹 Job ID: {job_info.id}")
|
||||
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||
For longer training sessions increase the timeout.
|
||||
|
||||
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float32 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
@@ -0,0 +1,29 @@
|
||||
# LeLab - LeRobot Guide
|
||||
|
||||
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
|
||||
|
||||
> [!TIP]
|
||||
> For now LeLab is compatible only with SO-ARM101
|
||||
|
||||
<Youtube id="VqyKUuW9V1g" />
|
||||
|
||||
### Installation
|
||||
|
||||
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
|
||||
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
After install, run `lelab` from your terminal anytime to start the app.
|
||||
|
||||
### Features
|
||||
|
||||
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
|
||||
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
|
||||
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
|
||||
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
|
||||
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
|
||||
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
|
||||
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
|
||||
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
|
||||
@@ -10,6 +10,7 @@ This docs will guide you to:
|
||||
- Stream datasets without downloading using `StreamingLeRobotDataset`
|
||||
- Apply image transforms for data augmentation during training
|
||||
- Migrate existing `v2.1` datasets to `v3.0`
|
||||
- Experiment with other `LeRobotDataset` formats and implementations like Lance
|
||||
|
||||
## What’s new in `v3`
|
||||
|
||||
@@ -43,7 +44,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
@@ -274,7 +275,7 @@ A converter aggregates per‑episode files into larger shards and writes episode
|
||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||
|
||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
```
|
||||
|
||||
**What it does**
|
||||
@@ -315,3 +316,39 @@ Dataset v3.0 uses incremental parquet writing with buffered metadata for efficie
|
||||
- Ensures the dataset is valid for loading
|
||||
|
||||
Without calling `finalize()`, your parquet files will be incomplete and the dataset won't load properly.
|
||||
|
||||
## Other formats and implementations
|
||||
|
||||
### Lance
|
||||
|
||||
Lance is a useful format for multimodal AI datasets, especially for large-scale training requiring high performance IO and random access.
|
||||
|
||||
The `lerobot-lancedb` package implements `LeRobotLanceDataset` (for JPEG images) and `LeRobotLanceVideoDataset` (for mp4 videos).
|
||||
Those two storage layouts both subclass LeRobotDataset and can provide data loading speed ups.
|
||||
|
||||
`LeRobotLanceDataset` is a drop-in replacement for `LeRobotDataset`:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot_lancedb import LeRobotLanceDataset, LeRobotLanceVideoDataset
|
||||
|
||||
cfg = DiffusionConfig(...)
|
||||
meta = LeRobotDatasetMetadata(root=local_dataset_path) # or use repo_id=... to load metadata from the Hub
|
||||
delta_timestamps = {...}
|
||||
|
||||
# Use LeRobotLanceDataset for image datasets
|
||||
dataset = LeRobotLanceDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
# Or use LeRobotLanceVideoDataset for video datasets:
|
||||
dataset = LeRobotLanceVideoDataset(
|
||||
root=local_dataset_path, # or use repo_id=... to stream from the Hub
|
||||
delta_timestamps=delta_timestamps,
|
||||
return_uint8=True,
|
||||
)
|
||||
```
|
||||
|
||||
Join the discussion on [Github](https://github.com/huggingface/lerobot/issues/3608) and explore the `lerobot-lancedb` documentation [here](https://lancedb.github.io/lerobot-lancedb/).
|
||||
|
||||
@@ -0,0 +1,433 @@
|
||||
# MolmoAct2 Policy
|
||||
|
||||
MolmoAct2 is the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into the LeRobot
|
||||
training, evaluation, checkpointing, and dataset interfaces for easier use with
|
||||
LeRobot datasets.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Install LeRobot with the MolmoAct2 optional dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[molmoact2]"
|
||||
```
|
||||
|
||||
To run the models in this repository, you need an NVIDIA GPU. The measurements
|
||||
below were taken on a single NVIDIA H100 80GB with bf16 model loading, LIBERO with two RGB cameras. MolmoAct2 rows use `chunk_size=10`, action dim 7
|
||||
padded to `expected_max_action_dim=32`, and `num_flow_timesteps=8`. Training measurements use
|
||||
`gradient_checkpointing=true` and include the forward pass, backward pass,
|
||||
gradient clipping, optimizer step, and optimizer state allocation. Values are
|
||||
peak GPU memory sampled with `nvidia-smi`. Leave a few GiB of headroom for
|
||||
dataloader workers, CUDA context, and fragmentation.
|
||||
|
||||
Multi-GPU training through `accelerate` increases throughput and global batch
|
||||
size, but this LeRobot port does not currently expose the original MolmoAct2
|
||||
`fsdp_devices` model-parallel training path. The current training script has
|
||||
not been tested for multi-node training.
|
||||
|
||||
| Mode | Peak Memory, bs=8 | Peak Memory, bs=16 | Peak Memory, bs=32 |
|
||||
| ------------------------------------------------ | ----------------: | -----------------: | -----------------: |
|
||||
| Inference, continuous, CUDA graph enabled (bs=1) | 12.1 GiB | - | - |
|
||||
| Fine-tuning, action expert only, continuous | 16.5 GiB | 18.3 GiB | 21.4 GiB |
|
||||
| Fine-tuning, LoRA VLM, both action modes | 20.2 GiB | 26.8 GiB | 41.3 GiB |
|
||||
| Fine-tuning, full model, both action modes | 48.3 GiB | 49.8 GiB | 60.1 GiB |
|
||||
|
||||
The repo has been tested with Ubuntu 22.04.
|
||||
|
||||
## Usage
|
||||
|
||||
To use MolmoAct2 in a LeRobot training config, set:
|
||||
|
||||
```python
|
||||
policy.type=molmoact2
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
MolmoAct2 can be fine-tuned from either the released MolmoAct2 Hugging Face
|
||||
checkpoint format or from a checkpoint already saved by LeRobot. Both routes use
|
||||
the same LeRobot training loop, dataset transforms, checkpoint saving, and
|
||||
logging. The difference is only how the initial policy weights and processor
|
||||
state are loaded.
|
||||
|
||||
### Training With Original MolmoAct2 Weight
|
||||
|
||||
Use `policy.checkpoint_path` when starting from a released MolmoAct2 checkpoint,
|
||||
for example `allenai/MolmoAct2` or `allenai/MolmoAct2-LIBERO`. LeRobot will load
|
||||
the original HF model files, then build its own policy processor from the
|
||||
dataset metadata and the policy options below.
|
||||
|
||||
The command below shows full fine-tuning on the merged LIBERO dataset. It uses
|
||||
bf16 model loading, 8 flow timesteps, LeRobot dataset statistics, image
|
||||
augmentation, and LeRobot's checkpointing/logging path.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.setup_type="single franka robotic arm in libero" \
|
||||
--policy.control_mode="delta end-effector pose" \
|
||||
--policy.image_keys='["observation.images.image","observation.images.wrist_image"]' \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.freeze_embedding=true \
|
||||
--policy.normalize_gripper=false \
|
||||
--policy.enable_knowledge_insulation=false \
|
||||
--policy.push_to_hub=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Training With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` when starting from a MolmoAct2 checkpoint that was saved by
|
||||
LeRobot, either from a local `pretrained_model` directory or from the Hub. This
|
||||
restores the saved LeRobot policy config, model weights, processor, and
|
||||
normalization statistics. You can still override training-time options such as
|
||||
`batch_size`, `steps`, LoRA flags, or `policy.action_mode`.
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--num_processes=8 \
|
||||
--mixed_precision=bf16 \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id=allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.root=/path/to/lerobot/data/allenai/MolmoAct2-LIBERO-Dataset \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.path=/path/to/pretrained_model \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=both \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.num_flow_timesteps=8 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--wandb.enable=true \
|
||||
--wandb.entity=<wandb_entity> \
|
||||
--wandb.project=<wandb_project> \
|
||||
--job_name=<job_name> \
|
||||
--output_dir=outputs/<job_name> \
|
||||
--steps=10000 \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
### Common Practices
|
||||
|
||||
For fine-tuning on a comparatively small dataset, such as a single LIBERO suite
|
||||
or a real-world dataset with less than 200 demonstrations, a global batch size of
|
||||
16 to 32 is a good starting point. In these settings, `policy.enable_lora_vlm=true` or `policy.train_action_expert_only=true` is also a practical choice. In both
|
||||
cases, we intentionally keep the action expert fully trainable, which we found
|
||||
to be crucial for model performance. For larger fine-tuning datasets, larger
|
||||
global batch sizes and full fine-tuning are usually preferred.
|
||||
|
||||
### Common Policy Options
|
||||
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint to initialize from.
|
||||
Use this for released MolmoAct2 weights.
|
||||
- `policy.path`: LeRobot checkpoint to initialize from. Use this for checkpoints
|
||||
created by LeRobot training.
|
||||
- `policy.action_mode`: training target, one of `continuous`, `discrete`, or
|
||||
`both`. `both` trains the flow-matching action expert and the discrete
|
||||
action-token loss.
|
||||
- `policy.train_action_expert_only`: trains only parameters whose names contain
|
||||
`action_expert`. It requires `policy.action_mode=continuous`.
|
||||
- `policy.enable_lora_vlm`: enables LoRA on VLM linear layers. Use
|
||||
`policy.enable_lora_action_expert=true` only if LoRA should also cover action
|
||||
expert linear layers. When `policy.enable_lora_action_expert=false`, the
|
||||
action expert base weights remain fully trainable while the VLM is trained
|
||||
through LoRA adapters. When `policy.enable_lora_action_expert=true`, the
|
||||
action expert is also adapter-tuned instead of fully fine-tuned.
|
||||
- `policy.enable_knowledge_insulation`: when `true`, detaches action-expert
|
||||
context K/V states before the action loss. The default is `false`.
|
||||
- `policy.chunk_size`: action horizon used by the policy. For LIBERO we use
|
||||
`10`. This LeRobot port overrides the loaded checkpoint's
|
||||
`max_action_horizon` with this value.
|
||||
- `policy.n_action_steps`: number of actions consumed from each predicted
|
||||
chunk before querying the policy again. For LIBERO, set it to `chunk_size`.
|
||||
- `policy.setup_type`: text inserted into the prompt to describe the robot and
|
||||
scene, e.g. `single franka robotic arm in libero`. More examples are listed
|
||||
in the `metadata_by_tag` entries of
|
||||
[`norm_stats.json`](https://huggingface.co/allenai/MolmoAct2/blob/main/norm_stats.json).
|
||||
- `policy.control_mode`: text inserted into the prompt to describe the action
|
||||
space, e.g. `delta end-effector pose` or `absolute joint pose`.
|
||||
- `policy.image_keys`: ordered LeRobot image observation keys passed to the
|
||||
processor.
|
||||
- `policy.model_dtype`: checkpoint/forward dtype, one of `float32`,
|
||||
`bfloat16`, or `float16`. Use `bfloat16` for normal training.
|
||||
- `policy.num_flow_timesteps`: number of flow-matching timesteps sampled per
|
||||
example during training. We use `8` for fine-tuning.
|
||||
- `policy.num_inference_steps`: optional override for continuous action
|
||||
generation steps at inference time.
|
||||
- `policy.gradient_checkpointing`: enables checkpointing in the VLM/action path
|
||||
to reduce activation memory.
|
||||
- `policy.freeze_embedding`: freezes input embeddings. The default is `true`.
|
||||
- `policy.normalize_gripper`: controls whether gripper dimensions are included
|
||||
in state/action quantile normalization. The default is `false`.
|
||||
- `policy.normalize_language`: normalizes task strings before prompt
|
||||
construction. The default is `true`.
|
||||
- `policy.mask_action_dim_padding`: masks padded dimensions in the flow loss.
|
||||
Released checkpoints use `policy.expected_max_action_dim=32`.
|
||||
- `policy.max_sequence_length`: optional manual sequence cap. Leave unset to
|
||||
infer it from images, state dimension, action dimension, action horizon, and
|
||||
discrete-action mode.
|
||||
|
||||
### Learning Rates
|
||||
|
||||
MolmoAct2 uses parameter-group learning rates to match the original MolmoAct2
|
||||
fine-tuning experiments.
|
||||
|
||||
- Full fine-tuning uses `policy.optimizer_lr=1e-5` for the VLM,
|
||||
`policy.optimizer_vit_lr=5e-6` for the vision tower,
|
||||
`policy.optimizer_connector_lr=5e-6` for image connector layers, and
|
||||
`policy.optimizer_action_expert_lr=5e-5` for the action expert.
|
||||
- LoRA VLM fine-tuning sets the VLM, vision, and connector LoRA parameter
|
||||
groups to `5e-5` when `policy.enable_lora_vlm=true`. By default,
|
||||
`policy.enable_lora_action_expert=false`, so the action expert is still fully
|
||||
fine-tuned with `policy.optimizer_action_expert_lr`. If
|
||||
`policy.enable_lora_action_expert=true`, the action expert is trained through
|
||||
LoRA adapters instead.
|
||||
- Action-expert-only fine-tuning trains only the action expert and uses
|
||||
`policy.optimizer_action_expert_lr=5e-5`.
|
||||
|
||||
You can override the full fine-tuning and action-expert learning rates with
|
||||
`policy.optimizer_lr`, `policy.optimizer_vit_lr`,
|
||||
`policy.optimizer_connector_lr`, and `policy.optimizer_action_expert_lr`.
|
||||
Scheduler settings can be changed with `policy.scheduler_warmup_steps`,
|
||||
`policy.scheduler_decay_steps`, and `policy.scheduler_decay_lr`.
|
||||
|
||||
### Dataset Quantile Statistics
|
||||
|
||||
MolmoAct2 defaults to quantile normalization for state and action features. If
|
||||
your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
Alternatively, train MolmoAct2 with mean/std normalization:
|
||||
|
||||
```bash
|
||||
--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Evaluation also supports both LeRobot-saved checkpoints and original MolmoAct2
|
||||
HF checkpoints. For LIBERO replication, keep the EGL rendering environment
|
||||
fixed and use `policy.per_episode_seed=true`.
|
||||
|
||||
**Important:** We found that `num_steps_wait=10` does not reliably let the
|
||||
LIBERO scene stabilize and can degrade measured success. All LIBERO evaluation
|
||||
results reported here use `num_steps_wait=50`.
|
||||
|
||||
### Evaluation With LeRobot MolmoAct2 Weight
|
||||
|
||||
Use `policy.path` for a checkpoint saved by LeRobot. The saved processor and
|
||||
normalization statistics are restored together with the model.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.path=allenai/MolmoAct2-LIBERO-LeRobot \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=bfloat16 \
|
||||
--policy.use_amp=true \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10,libero_goal,libero_object,libero_spatial \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
### Evaluation With Original MolmoAct2 Weight
|
||||
|
||||
You can evaluate a released Hugging Face checkpoint directly without first
|
||||
converting it to a LeRobot checkpoint. In this case, set
|
||||
`policy.checkpoint_path` to the HF model repo and provide `policy.norm_tag`.
|
||||
For LIBERO, `policy.norm_tag=libero` loads the LIBERO action/state
|
||||
normalization statistics, action horizon, prompt metadata, and image-key order
|
||||
from the checkpoint's `norm_stats.json`.
|
||||
|
||||
To fully replicate the MolmoAct2 paper results with released Hugging Face
|
||||
checkpoints, we recommend using the v0.5.1-pinned
|
||||
[`allenai/lerobot` `molmoact2-hf-inference`](https://github.com/allenai/lerobot/tree/molmoact2-hf-inference)
|
||||
branch. That branch matches the original evaluation settings used for the
|
||||
reported numbers.
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl
|
||||
export PYOPENGL_PLATFORM=egl
|
||||
export OMP_NUM_THREADS=1
|
||||
export MKL_NUM_THREADS=1
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=molmoact2 \
|
||||
--policy.checkpoint_path=allenai/MolmoAct2-LIBERO \
|
||||
--policy.norm_tag=libero \
|
||||
--policy.inference_action_mode=continuous \
|
||||
--policy.model_dtype=float32 \
|
||||
--policy.use_amp=false \
|
||||
--policy.enable_inference_cuda_graph=true \
|
||||
--policy.device=cuda \
|
||||
--policy.per_episode_seed=true \
|
||||
--policy.eval_seed=1000 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_goal \
|
||||
--env.camera_name_mapping='{"agentview_image":"image","robot0_eye_in_hand_image":"wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=1000
|
||||
```
|
||||
|
||||
Use `--env.task=libero_10,libero_goal,libero_object,libero_spatial` to run the
|
||||
full LIBERO suite. The same command works for other released MolmoAct2
|
||||
checkpoints as long as the requested `policy.norm_tag` exists in that
|
||||
checkpoint's `norm_stats.json`.
|
||||
|
||||
### Common Evaluation Options
|
||||
|
||||
- `policy.inference_action_mode`: required for rollout. Use `continuous` for
|
||||
flow-matching inference or `discrete` for action-token inference. It must be
|
||||
compatible with the training-time `policy.action_mode` saved in the
|
||||
checkpoint.
|
||||
- `policy.path`: LeRobot checkpoint path or Hub repo. Use this for checkpoints
|
||||
saved by LeRobot.
|
||||
- `policy.checkpoint_path`: original MolmoAct2 HF checkpoint path or Hub repo.
|
||||
Use this with `policy.type=molmoact2` and `policy.norm_tag`.
|
||||
- `policy.norm_tag`: selects normalization statistics, prompt metadata,
|
||||
image-key order, and action horizon from the original checkpoint's
|
||||
`norm_stats.json`. It is required for direct original-HF checkpoint
|
||||
evaluation.
|
||||
- `policy.model_dtype`: model load/forward dtype. Use `bfloat16` for normal
|
||||
GPU evaluation. Use `float32` only when you explicitly want fp32 inference.
|
||||
- `policy.use_amp`: runs the policy forward under autocast during eval. For
|
||||
`model_dtype=bfloat16`, keep this enabled.
|
||||
- `policy.enable_inference_cuda_graph`: enables the MolmoAct2 inference CUDA
|
||||
graph path for faster repeated continuous-action rollout.
|
||||
- `policy.per_episode_seed` and `policy.eval_seed`: make stochastic continuous
|
||||
action generation deterministic per episode for replication.
|
||||
- `env.task`: comma-separated LIBERO suites or a single suite. Use
|
||||
`libero_10,libero_goal,libero_object,libero_spatial` for the full benchmark.
|
||||
- `env.camera_name_mapping`: maps LIBERO camera names to the image keys expected
|
||||
by the policy processor.
|
||||
|
||||
## Performance Results
|
||||
|
||||
### LIBERO Benchmark Results
|
||||
|
||||
MolmoAct2 has demonstrated strong performance on the LIBERO benchmark suite. To
|
||||
compare and test its LeRobot implementation, we fine-tuned
|
||||
[`allenai/MolmoAct2-LIBERO`](https://huggingface.co/allenai/MolmoAct2-LIBERO)
|
||||
for an additional 10k steps on the LIBERO dataset with per-GPU batch size 32 on
|
||||
8 H100 GPUs, then compared the results to the original MolmoAct2 reference
|
||||
results.
|
||||
|
||||
The LeRobot fine-tuned checkpoint reported here is available at
|
||||
[`allenai/MolmoAct2-LIBERO-LeRobot`](https://huggingface.co/allenai/MolmoAct2-LIBERO-LeRobot)
|
||||
and was trained on
|
||||
[`allenai/MolmoAct2-LIBERO-Dataset`](https://huggingface.co/datasets/allenai/MolmoAct2-LIBERO-Dataset).
|
||||
|
||||
| Benchmark | LeRobot Implementation | MolmoAct2 Original |
|
||||
| -------------- | ---------------------: | -----------------: |
|
||||
| LIBERO Spatial | 98.4% | 97.8% |
|
||||
| LIBERO Object | 100.0% | 100.0% |
|
||||
| LIBERO Goal | 98.0% | 97.8% |
|
||||
| LIBERO 10 | 96.6% | 93.2% |
|
||||
| Average | 98.25% | 97.20% |
|
||||
|
||||
These results demonstrate MolmoAct2's strong performance across diverse robotic
|
||||
manipulation tasks. To reproduce them, follow the instructions in the LIBERO
|
||||
evaluation section.
|
||||
|
||||
## Differences From the Original Implementation
|
||||
|
||||
This LeRobot port is intended to match MolmoAct2 behavior while using LeRobot's
|
||||
dataset, training, evaluation, checkpoint, and logging infrastructure. The main
|
||||
differences from the original training repository are:
|
||||
|
||||
- The original paper training stack loads the model in fp32 and trains under
|
||||
mixed precision. This LeRobot port usually loads the checkpoint directly in
|
||||
`policy.model_dtype=bfloat16` for lower memory use.
|
||||
- The original repository uses its own FSDP/model-parallel training path. The
|
||||
LeRobot port uses the standard LeRobot/Accelerate training path and has not
|
||||
been tested for multi-node training.
|
||||
- The original repository supports sequence packing. The LeRobot port trains on
|
||||
one LeRobot sample per item and pads to an inferred fixed sequence budget.
|
||||
- The LeRobot port follows LeRobot's optimizer, scheduler, checkpoint saving,
|
||||
dataset transforms, image augmentation, and Weights & Biases logging
|
||||
conventions.
|
||||
- The original training path supports mixed action horizons by padding to
|
||||
`max_action_horizon` and masking padded horizon slots in the action expert
|
||||
self-attention. This is useful when training across datasets with different
|
||||
control frequencies. The LeRobot port currently targets single-dataset
|
||||
fine-tuning, so `policy.chunk_size` overrides the checkpoint
|
||||
`max_action_horizon` and horizon masking is not implemented yet. Support for
|
||||
this mixed-horizon path is planned.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
@@ -28,13 +28,15 @@ lerobot-train \
|
||||
--steps=100000 \
|
||||
--batch_size=32 \
|
||||
--peft.method_type=LORA \
|
||||
--peft.r=64
|
||||
--peft.r=64 \
|
||||
--peft.lora_alpha=64
|
||||
```
|
||||
|
||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
|
||||
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank
|
||||
the closer you get to full fine-tuning
|
||||
|
||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||
|
||||
@@ -91,7 +91,7 @@ lerobot-train \
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control. The LeRobot
|
||||
integration uses an InternVL3 vision-language backbone with a flow-matching
|
||||
action head, and supports staged training through the standard LeRobot policy
|
||||
APIs.
|
||||
|
||||
The upstream EVO1 project is available at
|
||||
[MINT-SJTU/Evo-1](https://github.com/MINT-SJTU/Evo-1).
|
||||
|
||||
```bibtex
|
||||
@misc{evo1,
|
||||
title = {EVO1},
|
||||
author = {{MINT-SJTU}},
|
||||
year = {2026},
|
||||
howpublished = {\url{https://github.com/MINT-SJTU/Evo-1}},
|
||||
}
|
||||
```
|
||||
@@ -1,6 +1,13 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
|
||||
|
||||
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
|
||||
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
|
||||
> Current releases support GR00T N1.7 only.
|
||||
|
||||
## Repository
|
||||
|
||||
@@ -24,4 +31,103 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
Hugging Face Models:
|
||||
|
||||
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
|
||||
|
||||
## Original-vs-LeRobot parity test
|
||||
|
||||
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||
over every embodiment tag present in the checkpoint:
|
||||
|
||||
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||
or action-dim mismatch fails the test.
|
||||
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||
state normalization, no mocks) must produce the **same collated model inputs**
|
||||
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||
`embodiment_id`) as the original package's processor.
|
||||
|
||||
### Why two environments
|
||||
|
||||
The original `gr00t` package pins `transformers==4.57.3` (Python 3.10); this
|
||||
integration requires `transformers>=5.x` (Qwen3-VL). Under 5.x, `PretrainedConfig`
|
||||
is itself a defaulted dataclass, so the original config dataclasses fail to import
|
||||
(`non-default argument follows default argument`). The two implementations therefore
|
||||
**cannot be imported in the same Python process**.
|
||||
|
||||
So the test uses a **producer / consumer** split across two venvs:
|
||||
|
||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||
(`in::` keys), the seed, and the raw `action_pred`.
|
||||
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||
preprocessor pipeline and asserts the collated tensors match.
|
||||
|
||||
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||
> Re-run the producer to refresh them.
|
||||
|
||||
### Fairness controls
|
||||
|
||||
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||
covered separately by the preprocessor-parity case, which compares its output
|
||||
against those same collated tensors from identical raw observations.
|
||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||
kernel/rounding noise, not an implementation difference.)
|
||||
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||
replays the recorded value.
|
||||
|
||||
### How to run
|
||||
|
||||
```bash
|
||||
# Resolve a local checkpoint (GR00T-N1.7-LIBERO / libero_10)
|
||||
CKPT=$(python - <<'PY'
|
||||
import os
|
||||
from huggingface_hub import snapshot_download
|
||||
print(os.path.join(snapshot_download("nvidia/GR00T-N1.7-LIBERO",
|
||||
allow_patterns=["libero_10/*"]), "libero_10"))
|
||||
PY
|
||||
)
|
||||
|
||||
# 1) Produce the original-side artifacts for all embodiments (original gr00t venv, CUDA)
|
||||
CUDA_VISIBLE_DEVICES=0 /path/to/Isaac-GR00T/.venv-original/bin/python \
|
||||
tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
--ckpt "$CKPT" --out-dir tests/policies/groot/artifacts --device cuda --seed 42
|
||||
|
||||
# 2) Run the parity test (LeRobot venv) — one parametrized case per embodiment
|
||||
CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||
```
|
||||
|
||||
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||
when the checkpoint / artifacts are absent.
|
||||
|
||||
#### Env knobs (all optional)
|
||||
|
||||
| Var | Default | Purpose |
|
||||
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# MolmoAct2
|
||||
|
||||
This repository contains the LeRobot policy implementation of
|
||||
[MolmoAct2](https://allenai.org/blog/molmoact2), ported into LeRobot for
|
||||
training, evaluation, checkpointing, and dataset compatibility.
|
||||
|
||||
This implementation currently supports training and evaluation for the regular
|
||||
MolmoAct2 model. MolmoAct2-Think, which supports adaptive depth reasoning, is
|
||||
not included in this LeRobot policy yet and is coming soon.
|
||||
|
||||
For the original MolmoAct2 training code used for the experiments reported in
|
||||
the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
|
||||
## LIBERO Evaluation
|
||||
|
||||
Important: we found that `num_steps_wait=10` does not reliably let the LIBERO
|
||||
scene stabilize and can degrade measured success. All LIBERO evaluation results
|
||||
reported for this LeRobot implementation use `num_steps_wait=50`.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{fang2026molmoact2actionreasoningmodels,
|
||||
title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
|
||||
author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
|
||||
year={2026},
|
||||
eprint={2605.02881},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2605.02881},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model is licensed under Apache 2.0. It is intended for research and
|
||||
educational use in accordance with
|
||||
[Ai2's Responsible Use Guidelines](https://allenai.org/responsible-use),
|
||||
consistent with [allenai/molmoact2](https://github.com/allenai/molmoact2).
|
||||
@@ -0,0 +1,39 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
|
||||
If you have existing datasets in v2.1 format, use the migration tool:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
|
||||
--repo-id your_id/existing_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
# reBot B601-DM
|
||||
|
||||
[reBot B601-DM](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/) is an open-source, low-cost robot arm from Seeed Studio for embodied-AI and imitation learning. It comes as a **follower** arm (the `B601-DM`, a 6-DOF arm plus gripper driven by Damiao CAN motors) and a **leader** arm (the `StarArm102` / `reBot Arm 102`, driven by FashionStar UART smart servos) used to teleoperate it.
|
||||
|
||||
This page covers **calibration** and **teleoperation** for both single-arm and bimanual (dual-arm) setups.
|
||||
|
||||
<div style="display: flex; align-items: center; gap: 10px;">
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/b601dm_zeroposition.jpg"
|
||||
alt="reBot B601-DM follower arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
<img
|
||||
src="https://files.seeedstudio.com/wiki/robotics/projects/lerobot/102_zeroposition.jpg"
|
||||
alt="reBot Arm 102 leader arm at its zero position"
|
||||
width="48%"
|
||||
/>
|
||||
</div>
|
||||
|
||||
_Left: the B601-DM follower at its zero position. Right: the reBot Arm 102 leader at its zero position. Images courtesy of [Seeed Studio](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/)._
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the reBot support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[rebot]"
|
||||
```
|
||||
|
||||
This pulls in `motorbridge` (CAN motor control for the B601-DM follower) and `motorbridge-smart-servo` (FashionStar UART servos for the reBot Arm 102 leader).
|
||||
|
||||
## Registered device types
|
||||
|
||||
| Type | Kind |
|
||||
| ------------------------ | -------------------------------------------- |
|
||||
| `rebot_b601_follower` | single-arm B601-DM follower robot |
|
||||
| `bi_rebot_b601_follower` | bimanual (dual-arm) follower robot |
|
||||
| `rebot_102_leader` | single-arm reBot Arm 102 leader teleoperator |
|
||||
| `bi_rebot_102_leader` | bimanual (dual-arm) leader teleoperator |
|
||||
|
||||
The bimanual types compose two single-arm instances and namespace each arm's
|
||||
observation/action keys with a `left_` / `right_` prefix. Per-arm settings are
|
||||
passed through nested `left_arm_config.*` / `right_arm_config.*` arguments.
|
||||
|
||||
## Find the USB ports
|
||||
|
||||
For each device, find the USB port associated with its motor bus using:
|
||||
|
||||
```bash
|
||||
lerobot-find-port
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
On Linux, remove `brltty` (`sudo apt remove brltty`) so it does not hold the
|
||||
leader's USB serial port. You may also need to grant access to the serial
|
||||
devices: `sudo chmod 666 /dev/ttyACM* /dev/ttyUSB*`.
|
||||
</Tip>
|
||||
|
||||
## Calibration
|
||||
|
||||
Neither arm stores a persistent hardware calibration: every time it connects, the motors are re-zeroed against the pose the arm is physically holding. Calibration simply records that zero pose. When prompted, **manually move the arm to its zero position** (the default sit-down pose shown above, gripper fully closed) and press <kbd>ENTER</kbd>.
|
||||
|
||||
### Follower (B601-DM)
|
||||
|
||||
<hfoptions id="calibrate-follower">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
Connect the bimanual follower; calibration runs for the left arm, then the right arm.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao
|
||||
```
|
||||
|
||||
Per-arm calibration files are saved with `_left` / `_right` suffixes on the id.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader (reBot Arm 102)
|
||||
|
||||
<hfoptions id="calibrate-leader">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperation
|
||||
|
||||
Once both arms are calibrated, drive the follower with the leader. The follower talks to its CAN bus through a Damiao serial bridge (`can_adapter=damiao`, the default) or a SocketCAN adapter (`can_adapter=socketcan`). See the [OpenArm page](./openarm) for more details on the SocketCAN adapter configuration.
|
||||
|
||||
<hfoptions id="teleoperate">
|
||||
<hfoption id="Single arm">
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.id=follower \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Dual arm">
|
||||
|
||||
The bimanual leader and follower reuse the single-arm classes; each arm is
|
||||
configured through nested `left_arm_config.*` / `right_arm_config.*` arguments,
|
||||
so a bimanual reBot Arm 102 leader drives a bimanual B601-DM follower.
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_rebot_b601_follower \
|
||||
--robot.id=bi_follower \
|
||||
--robot.left_arm_config.port=/dev/ttyACM0 \
|
||||
--robot.left_arm_config.can_adapter=damiao \
|
||||
--robot.right_arm_config.port=/dev/ttyACM1 \
|
||||
--robot.right_arm_config.can_adapter=damiao \
|
||||
--teleop.type=bi_rebot_102_leader \
|
||||
--teleop.id=bi_leader \
|
||||
--teleop.left_arm_config.port=/dev/ttyUSB0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyUSB1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip>
|
||||
The leader and follower share the same joint names (`shoulder_pan,
|
||||
shoulder_lift, elbow_flex, wrist_flex, wrist_yaw, wrist_roll, gripper`), so
|
||||
leader actions map directly onto the follower.
|
||||
</Tip>
|
||||
|
||||
If the motion of a joint is reversed, flip its sign in the leader's `joint_directions` (the gripper also carries a scale to widen its range to the follower):
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=rebot_b601_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.can_adapter=damiao \
|
||||
--teleop.type=rebot_102_leader \
|
||||
--teleop.port=/dev/ttyUSB0 \
|
||||
--teleop.joint_directions='{"shoulder_pan":-1,"shoulder_lift":-1,"elbow_flex":1,"wrist_flex":1,"wrist_yaw":1,"wrist_roll":-1,"gripper":-6}'
|
||||
```
|
||||
|
||||
## Recording datasets
|
||||
|
||||
Swap `lerobot-teleoperate` for `lerobot-record` (with the same `--robot.*` / `--teleop.*` arguments, plus `--dataset.*`) to record demonstrations for training. See [Imitation Learning for Robots](./il_robots) for the full workflow.
|
||||
|
||||
For hardware assembly and wiring, see the [Seeed Studio reBot wiki](https://wiki.seeedstudio.com/rebot_arm_b601_dm_lerobot/).
|
||||
@@ -0,0 +1,185 @@
|
||||
# ROBOMETER
|
||||
|
||||
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
|
||||
|
||||
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
|
||||
**Project**: [robometer.github.io](https://robometer.github.io/)
|
||||
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
|
||||
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
|
||||
|
||||
## Overview
|
||||
|
||||
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
|
||||
|
||||
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
|
||||
- **Success head**: predicts per-frame task success probability.
|
||||
- **Preference head**: predicts which of two trajectories better completes the task during training.
|
||||
|
||||
The paper trains ROBOMETER with a composite objective:
|
||||
|
||||
```text
|
||||
L = L_pref + L_prog + L_succ
|
||||
```
|
||||
|
||||
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
|
||||
|
||||
## What the LeRobot Integration Covers
|
||||
|
||||
- Standard `reward_model.type=robometer` configuration through LeRobot.
|
||||
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
|
||||
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
|
||||
- Dense, frame-level progress and success predictions internally.
|
||||
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
|
||||
|
||||
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install the ROBOMETER dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[robometer]"
|
||||
```
|
||||
|
||||
If you use `uv` directly from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra robometer
|
||||
```
|
||||
|
||||
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
ROBOMETER expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets, the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | ------------------------ | ----------------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
|
||||
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
|
||||
|
||||
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
|
||||
|
||||
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
|
||||
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the Reward Model Directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||
|
||||
cfg = RobometerConfig(
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
reward_output="progress",
|
||||
)
|
||||
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
|
||||
```
|
||||
|
||||
### Encode Frames and Compute a Reward
|
||||
|
||||
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
|
||||
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
|
||||
# task: str
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=cfg.base_model_id,
|
||||
use_multi_image=cfg.use_multi_image,
|
||||
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
|
||||
max_frames=cfg.max_frames,
|
||||
)
|
||||
|
||||
encoded = encoder.encode_samples([(frames, task)])
|
||||
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
|
||||
|
||||
reward = reward_model.compute_reward(batch)
|
||||
```
|
||||
|
||||
`reward` is a tensor of shape `(batch_size,)`.
|
||||
|
||||
### Use the Reward Factory
|
||||
|
||||
You can also instantiate ROBOMETER through the reward factory:
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"robometer",
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Backbone and Vocabulary
|
||||
|
||||
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
|
||||
|
||||
```text
|
||||
<|split_token|>
|
||||
<|reward_token|>
|
||||
<|pref_token|>
|
||||
<|sim_token|>
|
||||
<|prog_token|>
|
||||
```
|
||||
|
||||
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
|
||||
|
||||
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
|
||||
|
||||
### Progress Prediction
|
||||
|
||||
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
|
||||
|
||||
### Success Prediction
|
||||
|
||||
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
|
||||
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
|
||||
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
|
||||
|
||||
## References
|
||||
|
||||
- [ROBOMETER project](https://robometer.github.io/)
|
||||
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
|
||||
- [Original ROBOMETER code](https://github.com/robometer/robometer)
|
||||
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
|
||||
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{liang2026robometer,
|
||||
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
|
||||
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
|
||||
year={2026},
|
||||
booktitle={Robotics: Science and Systems 2026},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
|
||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
||||
Once you are logged in, you can run inference in your setup by doing:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||
# <- RTC optional, use when running on low power hardware \
|
||||
# --inference.type=rtc \
|
||||
# --inference.rtc.execution_horizon=10 \
|
||||
# --inference.rtc.max_guidance_weight=10.0 \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_red_leader_arm \
|
||||
# --display_data=true #optional use if you want to see the camera stream \
|
||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||
```
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `vcodec` | `--dataset.camera_encoder.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `30` | Max buffered frames per camera (~1s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -48,7 +48,7 @@ This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
### Backpressure and Frame Dropping
|
||||
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 30 frames). When the encoder can't keep up:
|
||||
|
||||
1. The queue fills up (consuming RAM)
|
||||
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
|
||||
@@ -82,15 +82,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | --------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -100,15 +100,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -146,7 +146,7 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.camera_encoder.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
@@ -0,0 +1,177 @@
|
||||
# TOPReward
|
||||
|
||||
TOPReward is a **zero-shot reward model** that extracts token log-probabilities from an off-the-shelf vision-language model (VLM) as a robotic reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood that the instruction is true — no fine-tuning required.
|
||||
|
||||
**Paper**: [TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics](https://arxiv.org/abs/2602.19313)
|
||||
**Project**: [topreward.github.io](https://topreward.github.io/webpage/)
|
||||
**Original code**: [github.com/TOPReward/TOPReward](https://github.com/TOPReward/TOPReward)
|
||||
**Default backbone**: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Overview
|
||||
|
||||
TOPReward asks a generic VLM how likely a task instruction is, **conditioned on the video** of a robot trying to complete that task. Concretely, given:
|
||||
|
||||
- A trajectory video (a sequence of frames).
|
||||
- A task instruction (e.g. _"open the drawer"_).
|
||||
|
||||
it builds a chat prompt of the form
|
||||
|
||||
```text
|
||||
<video>
|
||||
"The above video shows a robot manipulation trajectory that completes the
|
||||
following task: <instruction> Decide whether the above statement is True
|
||||
or not. The answer is: True"
|
||||
```
|
||||
|
||||
forwards it through the VLM, label-masks everything except the very last token, and reads back the log-probability of that token — by default the literal `"True"` that closes the suffix template. The resulting `log P("True" | video + prompt + instruction)` is the reward.
|
||||
|
||||
Because the method only depends on a frozen VLM, TOPReward is **zero-shot**: there are no fine-tuned weights to host. The "model" in LeRobot is a small wrapper around `transformers`' `Qwen3VLForConditionalGeneration` plus the label-masking logic. The processor owns the tokeniser and builds the full chat prompt (EO-1/Robometer pattern).
|
||||
|
||||
## What the LeRobot integration covers
|
||||
|
||||
- Standard `reward_model.type=topreward` configuration through LeRobot.
|
||||
- VLM loading via the `transformers` `Qwen3VLForConditionalGeneration` API.
|
||||
- Prompt assembly + tokenisation in the processor (matching upstream `QwenClient.compute_instruction_reward`).
|
||||
- `compute_reward()` returns one scalar log-prob per sample.
|
||||
- LeRobot reward-model save/load — `save_pretrained` writes only `config.json` (the VLM is identified by `vlm_name`).
|
||||
- An offline labeling script that writes a `topreward_progress.parquet` (SARM-compatible schema) for RA-BC and overlay.
|
||||
|
||||
The current LeRobot port supports the **Qwen3-VL client only**. Other upstream clients (Gemini, OpenAI, Gemma, Molmo) can be added as follow-up extras.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot following the [Installation Guide](./installation).
|
||||
2. Install the TOPReward optional extra:
|
||||
|
||||
```bash
|
||||
pip install -e ".[topreward]"
|
||||
```
|
||||
|
||||
or, with `uv` from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra topreward
|
||||
```
|
||||
|
||||
This pulls in `transformers`. The first time you run TOPReward, Hugging Face will also download the VLM weights from the Hub (~16 GB for Qwen3-VL-8B-Instruct). A GPU is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
TOPReward expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | --------------------------- | --------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera observation used by TOPReward |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data for the task string |
|
||||
| `reward_model.max_frames` | `16` | Cap on frames per sample |
|
||||
| `reward_model.fps` | `2.0` | Metadata passed to the Qwen video processor |
|
||||
| `reward_model.vlm_name` | `Qwen/Qwen3-VL-8B-Instruct` | Hugging Face Hub id of the underlying VLM |
|
||||
|
||||
The model returns:
|
||||
|
||||
- `compute_reward(batch)`: one log-probability per sample. Higher = better task-video alignment. When `success_threshold` is finite, returns the binary thresholded value instead.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the reward model directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.topreward import TOPRewardConfig, TOPRewardModel
|
||||
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
)
|
||||
reward_model = TOPRewardModel(cfg)
|
||||
```
|
||||
|
||||
### Use the reward factory
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"topreward",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor tokenises the full prompt (video + prefix + instruction suffix), writes Qwen-VL tensors + `prompt_length` under `observation.topreward.*`. The model reads those tensors, label-masks based on `prompt_length`, and extracts the log-prob reward.
|
||||
|
||||
### Offline dataset labeling
|
||||
|
||||
Write a `topreward_progress.parquet` for RA-BC training and overlay videos:
|
||||
|
||||
```bash
|
||||
# Sparse-dense (15 anchors per episode, matches upstream)
|
||||
uv run python -m lerobot.rewards.topreward.compute_rabc_weights \
|
||||
--dataset-repo-id lerobot/libero_10_image \
|
||||
--num-samples 15 \
|
||||
--device cuda
|
||||
```
|
||||
|
||||
Then render the progress overlay for any episode:
|
||||
|
||||
```bash
|
||||
uv run examples/dataset/create_progress_videos.py \
|
||||
--repo-id lerobot/libero_10_image \
|
||||
--episode 0 \
|
||||
--progress-file topreward_progress.parquet \
|
||||
--gif
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Prompt knobs
|
||||
|
||||
The default prompt mirrors the upstream paper:
|
||||
|
||||
```text
|
||||
prompt_prefix = "The above video shows a robot manipulation trajectory that completes the following task: "
|
||||
prompt_suffix_template = "{instruction} Decide whether the above statement is True or not. The answer is: True"
|
||||
```
|
||||
|
||||
Both are exposed on `TOPRewardConfig` for ablation. The suffix template **must** contain `{instruction}`.
|
||||
|
||||
### Chat template
|
||||
|
||||
`add_chat_template=True` wraps the full prompt (including instruction) with the tokenizer's chat template before tokenisation. Default is `False`, matching the upstream paper's main experiments.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot port is **inference-only and zero-shot**; `forward()` is not overridden and `is_trainable` returns `False`.
|
||||
- Only the **Qwen3-VL family** is supported; other upstream clients are out of scope.
|
||||
- TOPReward inherits the underlying VLM's biases.
|
||||
|
||||
## References
|
||||
|
||||
- [TOPReward project page](https://topreward.github.io/webpage/)
|
||||
- [TOPReward paper](https://arxiv.org/abs/2602.19313)
|
||||
- [Original TOPReward code](https://github.com/TOPReward/TOPReward)
|
||||
- [Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chen2026topreward,
|
||||
title={TOPReward: Token Probabilities as Hidden Zero-Shot Rewards for Robotics},
|
||||
author={Chen, Shirui and Harrison, Cole and Lee, Ying-Chun and Yang, Angela Jin and
|
||||
Ren, Zhongzheng and Ratliff, Lillian J and Duan, Jiafei and Fox, Dieter and
|
||||
Krishna, Ranjay},
|
||||
journal={arXiv preprint arXiv:2602.19313},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
The original TOPReward codebase is MIT-licensed. The LeRobot port follows the LeRobot Apache 2.0 license; the wrapped Qwen3-VL weights are subject to the original Qwen license.
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
--operation.camera_encoder.vcodec libsvtav1 \
|
||||
--operation.camera_encoder.pix_fmt yuv420p \
|
||||
--operation.camera_encoder.g 2 \
|
||||
--operation.camera_encoder.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,11 +147,7 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `camera_encoder`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder.<field>. See [Video Encoding Parameters](./video_encoding_parameters) for more details.
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
# Video encoding parameters
|
||||
|
||||
When video storage is enabled, LeRobot stores each camera stream as an **MP4** file instead of saving one image file per timestep. Video encoding compresses across time, which usually cuts dataset size and I/O compared to a pile of PNG, while keeping MP4 — a format every player and loader understands.
|
||||
|
||||
Encoding frames into an MP4 is a full FFmpeg pipeline: choice of encoder, pixel format, GOP/keyframes, quality vs. speed, and optional extra encoder flags. Most of these knobs are user-tunable through `camera_encoder`, a nested `VideoEncoderConfig` (`lerobot.configs.video.VideoEncoderConfig`) passed through PyAV.
|
||||
|
||||
You can set these parameters from the CLI with `--dataset.camera_encoder.<field>` (e.g. with `lerobot-record` or `lerobot-rollout`). The same block applies to every camera video stream in that run.
|
||||
|
||||
<Tip>
|
||||
Video storage must be on for `camera_encoder` to have any effect —
|
||||
`use_videos=True` in Python APIs, or `--dataset.video=true` on the CLI (the
|
||||
recording default). With video off, inputs stay as images and `camera_encoder`
|
||||
is ignored.
|
||||
</Tip>
|
||||
|
||||
For details on **when** frames are written vs. encoded (streaming vs. post-episode), queues, and other top-level `--dataset.*` switches, see [Streaming Video Encoding](./streaming_video_encoding). For an encoding-parameter comparison and experiments, see the [video-benchmark Space](https://huggingface.co/spaces/lerobot/video-benchmark).
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--teleop.id=blue \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--dataset.camera_encoder.vcodec=h264 \
|
||||
--dataset.camera_encoder.preset=fast \
|
||||
--dataset.camera_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tuning parameters
|
||||
|
||||
<Tip warning={true}>
|
||||
The defaults are tuned to balance **compression ratio**, **visual quality**, and **decoding/seek speed** for typical robotics datasets. Changing them can affect both recording (CPU load, frame drops) and training (decoding throughput, image quality).
|
||||
|
||||
Only override these parameters if you have a specific reason to, and measure the impact on your pipeline before relying on the new settings.
|
||||
|
||||
</Tip>
|
||||
|
||||
All flags below are prefixed with `--dataset.camera_encoder.` on the CLI.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| --------------- | ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vcodec` | `str` | `"libsvtav1"` | Video codec name. `"auto"` picks the first available hardware encoder from a fixed preference list, falling back to `libsvtav1`. |
|
||||
| `pix_fmt` | `str` | `"yuv420p"` | Output pixel format. Must be supported by the chosen codec in your FFmpeg build. |
|
||||
| `g` | `int` | `2` | GOP size — a keyframe every `g` frames. Emitted as FFmpeg option `g`. |
|
||||
| `crf` | `int` or `float` | `30` | Abstract quality value, mapped per codec (see the [mapping](#mapping-videoencoderconfig--ffmpeg-options) below). Lower → higher quality / larger output where the mapping is monotone. |
|
||||
| `preset` | `int` or `str` | `12` \* | Encoder speed preset; meaning depends on the codec. <br/>\* When unset and `vcodec=libsvtav1`, LeRobot defaults to `12`. |
|
||||
| `fast_decode` | `int` | `0` | `libsvtav1`: `0–2`, passed via `svtav1-params`. <br/>`h264` / `hevc` (software): if `>0`, sets `tune=fastdecode`. <br/>Other codecs: usually unused. |
|
||||
| `video_backend` | `str` | `"pyav"` | Only `"pyav"` is currently implemented for video encoding. |
|
||||
| `extra_options` | `dict` | `{}` | Extra FFmpeg or codec specific options merged after the structured fields above. Cannot override keys already set by those fields. |
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
After the first episode of a video stream is encoded, the encoder configuration is **persisted into the dataset metadata** (`meta/info.json`) under each video feature, alongside the values probed from the file itself. For a video feature `observation.images.<camera>`, the layout in `info.json` is:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": {
|
||||
"observation.images.laptop": {
|
||||
"dtype": "video",
|
||||
"shape": [480, 640, 3],
|
||||
"info": {
|
||||
"video.height": 480,
|
||||
"video.width": 640,
|
||||
"video.codec": "h264",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
"video.fast_decode": 0,
|
||||
"video.video_backend": "pyav",
|
||||
"video.extra_options": { "tune": "film", "profile:v": "high", "bf": 2 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
This block is populated **once**, from the **first** episode. It assumes every
|
||||
episode in the dataset was encoded with the same `camera_encoder`. Changing
|
||||
encoder settings partway through a recording is not supported — the
|
||||
`info.json` will only reflect the parameters used for the first episode.
|
||||
</Tip>
|
||||
|
||||
---
|
||||
|
||||
## Merging datasets
|
||||
|
||||
When aggregating datasets with `merge_datasets`, video files are concatenated as-is (no re-encoding), and encoder fields in `info.json` are merged per-key:
|
||||
|
||||
- **Stream-derived fields must match** across sources: `video.codec`, `video.pix_fmt`, `video.height`, `video.width`, `video.fps`. Otherwise FFmpeg's concat demuxer fails.
|
||||
- **Encoder-tuning fields are merged loosely**: `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.extra_options`. If every source agrees, the value is kept; if not, it's set to `null` (or `{}` for `video.extra_options`) and a warning is logged.
|
||||
@@ -0,0 +1,235 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
||||
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||
|
||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||
of the source video, draws a progress line on each frame, and writes the result.
|
||||
The progress data is read from a parquet file that lives alongside the dataset
|
||||
(configurable via ``--progress-file``).
|
||||
|
||||
Usage:
|
||||
python examples/dataset/create_progress_videos.py \
|
||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the metadata and sarm_progress files for a dataset.
|
||||
def download_episode_metadata(
|
||||
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> Path:
|
||||
"""Download only the metadata and per-frame progress file for a dataset.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID.
|
||||
episode: Episode index (used for logging only; all meta is fetched).
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Local cache path for the downloaded snapshot.
|
||||
"""
|
||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
||||
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||
local_path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
allow_patterns=["meta/**", progress_file],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for an episode.
|
||||
def load_progress_data(
|
||||
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||
) -> np.ndarray | None:
|
||||
"""Load per-frame progress values for an episode.
|
||||
|
||||
Args:
|
||||
local_path: Dataset cache root.
|
||||
episode: Episode index.
|
||||
progress_file: Filename of the per-frame progress parquet.
|
||||
|
||||
Returns:
|
||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||
"""
|
||||
parquet_path = local_path / "sarm_progress.parquet"
|
||||
parquet_path = local_path / progress_file
|
||||
if not parquet_path.exists():
|
||||
logging.warning("sarm_progress.parquet not found")
|
||||
logging.warning("%s not found", progress_file)
|
||||
return None
|
||||
df = pd.read_parquet(parquet_path)
|
||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
||||
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||
episode_df = df[df["episode_index"] == episode].copy()
|
||||
if episode_df.empty:
|
||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
||||
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||
return None
|
||||
episode_df = episode_df.sort_values("frame_index")
|
||||
|
||||
@@ -576,6 +585,7 @@ def process_dataset(
|
||||
camera_key: str | None,
|
||||
output_dir: Path,
|
||||
create_gif: bool = False,
|
||||
progress_file: str = "sarm_progress.parquet",
|
||||
) -> Path | None:
|
||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||
|
||||
@@ -585,6 +595,8 @@ def process_dataset(
|
||||
camera_key: Camera key to use, or None for auto-selection.
|
||||
output_dir: Directory to write output files.
|
||||
create_gif: If True, also generate a GIF from the MP4.
|
||||
progress_file: Filename of the per-frame progress parquet inside the
|
||||
dataset repo.
|
||||
|
||||
Returns:
|
||||
Path to the final output file, or None on failure.
|
||||
@@ -592,7 +604,7 @@ def process_dataset(
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||
|
||||
local_path = download_episode_metadata(repo_id, episode)
|
||||
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||
logging.info(" Local cache: %s", local_path)
|
||||
|
||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||
@@ -600,9 +612,9 @@ def process_dataset(
|
||||
|
||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||
|
||||
progress_data = load_progress_data(local_path, episode)
|
||||
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||
if progress_data is None:
|
||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
||||
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||
return None
|
||||
|
||||
logging.info(" Progress frames: %d", len(progress_data))
|
||||
@@ -627,7 +639,7 @@ def process_dataset(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
||||
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
@@ -658,6 +670,15 @@ def main() -> None:
|
||||
action="store_true",
|
||||
help="Also generate a GIF from the MP4 output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-file",
|
||||
type=str,
|
||||
default="sarm_progress.parquet",
|
||||
help=(
|
||||
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||
"(default: 'sarm_progress.parquet')."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
@@ -670,6 +691,7 @@ def main() -> None:
|
||||
camera_key=args.camera_key,
|
||||
output_dir=args.output_dir,
|
||||
create_gif=args.gif,
|
||||
progress_file=args.progress_file,
|
||||
)
|
||||
|
||||
if result:
|
||||
|
||||
@@ -80,7 +80,7 @@
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Dataset\n",
|
||||
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
|
||||
"HF_USER = \"your_hf_username\" # `hf auth whoami` to find your username\n",
|
||||
"DATASET_NAME = \"my_so101_dataset\"\n",
|
||||
"TASK_DESCRIPTION = \"pick and place the block\"\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
@@ -291,7 +291,34 @@
|
||||
"\n",
|
||||
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
|
||||
"\n",
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details.\n",
|
||||
"\n",
|
||||
"Recently ```lerobot-rollout``` was introduced, you can [read more about it here](https://huggingface.co/docs/lerobot/main/en/il_robots?eval=Base+mode+%28no+recording%29#run-inference-and-evaluate-your-policy)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-rollout\",\n",
|
||||
" \"--strategy.type=base\",\n",
|
||||
" f\"--policy.path={POLICY_PATH}\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f'--task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--duration=60\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"if you are using the V0.5.1 release you should use ```lerobot-record``` instead of rollout"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -40,8 +40,9 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -83,24 +84,26 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
stats = algorithm.update(batch_iter())
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -113,7 +116,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -144,15 +147,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
@@ -261,14 +264,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
policy_cfg = GaussianActorConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
+35
-8
@@ -95,11 +95,22 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
"torchcodec>=0.3.0,<0.12.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
|
||||
|
||||
# NOTE: torchcodec wheel availability matrix (PyPI):
|
||||
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
|
||||
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
|
||||
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
|
||||
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
|
||||
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
|
||||
|
||||
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
|
||||
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
|
||||
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
|
||||
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
]
|
||||
training = [
|
||||
@@ -127,7 +138,9 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
# Common
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
# NOTE: 0.9.16 links against liburdfdom_sensor.so.4, which is unavailable on Ubuntu 24.04
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
@@ -140,6 +153,8 @@ pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
motorbridge-dep = ["motorbridge>=0.3.2,<0.4.0"]
|
||||
motorbridge-smart-servo-dep = ["motorbridge-smart-servo>=0.0.4,<0.1.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
@@ -163,6 +178,9 @@ unitree_g1 = [
|
||||
"lerobot[pygame-dep]",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
|
||||
# leader (motorbridge-smart-servo / FashionStar UART servos).
|
||||
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||
@@ -180,6 +198,7 @@ wallx = [
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
@@ -193,10 +212,12 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
evo1 = ["lerobot[transformers-dep]", "timm>=1.0.0,<1.1.0"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -250,17 +271,19 @@ all = [
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[openarms]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[rebot]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[diffusion]",
|
||||
"lerobot[multi_task_dit]",
|
||||
"lerobot[wallx]",
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[evo1]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -271,6 +294,8 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
@@ -350,7 +375,6 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
@@ -387,8 +411,11 @@ default.extend-ignore-identifiers-re = [
|
||||
"ein",
|
||||
"thw",
|
||||
"inpt",
|
||||
"arange",
|
||||
"is_compileable",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE"
|
||||
"OT_VALUE",
|
||||
"VanderBilt"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -199,12 +199,13 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
self._validate_fourcc()
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
set_fourcc_after_size_and_fps = platform.system() == "Windows"
|
||||
if self.config.fourcc is not None and not set_fourcc_after_size_and_fps:
|
||||
self._validate_fourcc()
|
||||
|
||||
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
|
||||
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
||||
|
||||
@@ -222,6 +223,11 @@ class OpenCVCamera(Camera):
|
||||
else:
|
||||
self._validate_fps()
|
||||
|
||||
if self.config.fourcc is not None and set_fourcc_after_size_and_fps:
|
||||
# On Windows with DSHOW, changing the resolution can silently override the FOURCC setting.
|
||||
# Set FOURCC last to make sure the requested pixel format is actually enforced.
|
||||
self._validate_fourcc()
|
||||
|
||||
def _validate_fps(self) -> None:
|
||||
"""Validates and sets the camera's frames per second (FPS)."""
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -24,6 +24,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -31,6 +32,12 @@ from .types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from .video import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VIDEO_ENCODER_INFO_KEYS,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Types
|
||||
@@ -43,7 +50,16 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
"VideoEncoderConfig",
|
||||
# Defaults
|
||||
"camera_encoder_defaults",
|
||||
# Constants
|
||||
"VALID_VIDEO_CODECS",
|
||||
"VIDEO_ENCODER_INFO_KEYS",
|
||||
]
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from .video import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
@@ -39,8 +41,8 @@ class DatasetRecordConfig:
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
|
||||
private: bool | None = None
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
@@ -55,10 +57,9 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Video encoder settings for camera MP4s (codec, quality, GOP, etc.). Tuned via CLI nested keys,
|
||||
# e.g. ``--dataset.camera_encoder.vcodec=h264`` (see ``VideoEncoderConfig``).
|
||||
camera_encoder: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
@@ -117,3 +117,9 @@ class PeftConfig:
|
||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||
# fine-tuning.
|
||||
r: int = 16
|
||||
|
||||
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
|
||||
# In general, a higher alpha means stronger adaptation signal.
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
@@ -18,8 +18,8 @@ from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import envs, policies # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
|
||||
from . import parser
|
||||
from .default import EvalConfig
|
||||
from .policies import PreTrainedConfig
|
||||
|
||||
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
|
||||
else:
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import pkgutil
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import ArgumentError
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import wraps
|
||||
@@ -24,6 +26,7 @@ from types import ModuleType
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import draccus
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
from lerobot.utils.utils import has_method
|
||||
|
||||
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
|
||||
# Storage for path args extracted from YAML/JSON config files, so that
|
||||
# get_path_arg() can find them even when they weren't passed via CLI.
|
||||
_config_path_args: dict[str, str] = {}
|
||||
|
||||
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
|
||||
_config_yaml_overrides: dict[str, list[str]] = {}
|
||||
|
||||
|
||||
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
|
||||
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
|
||||
args = []
|
||||
for key, value in d.items():
|
||||
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
|
||||
continue
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, bool):
|
||||
value = str(value).lower()
|
||||
if isinstance(value, dict):
|
||||
args.extend(_flatten_to_cli_args(value, full_key))
|
||||
elif value is not None and not isinstance(value, list):
|
||||
args.append(f"--{full_key}={value}")
|
||||
return args
|
||||
|
||||
|
||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||
"""Parses arguments from cli at a given nested attribute level.
|
||||
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
if result is None:
|
||||
result = _config_path_args.get(field_name)
|
||||
return result
|
||||
|
||||
|
||||
def get_yaml_overrides(field_name: str) -> list[str]:
|
||||
return _config_yaml_overrides.get(field_name, [])
|
||||
|
||||
|
||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
@@ -192,6 +225,51 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
return filtered_args
|
||||
|
||||
|
||||
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
|
||||
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
|
||||
|
||||
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
|
||||
draccus will fail because ``path`` is not a valid field on policy config classes.
|
||||
This function extracts those path values, stores them in ``_config_path_args`` for
|
||||
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
|
||||
"""
|
||||
config_file = Path(config_path)
|
||||
suffix = config_file.suffix.lower()
|
||||
|
||||
if suffix in (".yaml", ".yml"):
|
||||
with open(config_file) as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
elif suffix == ".json":
|
||||
with open(config_file) as f:
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
return config_path
|
||||
|
||||
if not isinstance(config_data, dict):
|
||||
return config_path
|
||||
|
||||
modified = False
|
||||
for field in path_fields:
|
||||
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
|
||||
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
|
||||
remaining = config_data[field]
|
||||
if remaining:
|
||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||
del config_data[field]
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
return config_path
|
||||
|
||||
# Write cleaned config to a temp file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
|
||||
if suffix in (".yaml", ".yml"):
|
||||
yaml.dump(config_data, tmp, default_flow_style=False)
|
||||
else:
|
||||
json.dump(config_data, tmp, indent=2)
|
||||
return tmp.name
|
||||
|
||||
|
||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
@@ -225,11 +303,20 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
cli_args = filter_path_args(path_fields, cli_args)
|
||||
# Also extract path fields from the YAML/JSON config file
|
||||
if config_path_cli:
|
||||
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
|
||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||
else:
|
||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
||||
if config_path_cli:
|
||||
cli_args = filter_arg("config_path", cli_args)
|
||||
cfg = draccus.parse(
|
||||
config_class=argtype,
|
||||
config_path=config_path_cli or config_path,
|
||||
args=cli_args,
|
||||
)
|
||||
response = fn(cfg, *args, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
@@ -27,12 +27,13 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .types import PolicyFeature
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,9 +90,9 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
def get_optimizer_preset(self) -> OptimizerConfig | None:
|
||||
"""Default optimizer for this reward model, or ``None`` for zero-shot models."""
|
||||
return None
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
@@ -25,11 +25,11 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
@@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
config_path = parser.parse_arg("config_path")
|
||||
@@ -174,6 +177,12 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if self.rename_map and active_cfg.pretrained_path is None:
|
||||
raise ValueError(
|
||||
"`rename_map` requires a pretrained policy checkpoint. "
|
||||
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
@@ -269,10 +278,3 @@ class TrainPipelineConfig(HubMixin):
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Note: We subclass str so that serialization is straightforward
|
||||
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
|
||||
|
||||
"""Video encoder configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and the chosen video backend.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_VIDEO_CODECS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
|
||||
# Aliases for legacy video codec names.
|
||||
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
|
||||
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
# Keys persisted under ``features[*]["info"]`` as ``video.<name>`` (from :class:`VideoEncoderConfig`).
|
||||
# ``vcodec``` and ``pix_fmt`` are derived from the video stream directly.
|
||||
VIDEO_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset(
|
||||
{"g", "crf", "preset", "fast_decode", "extra_options", "video_backend"}
|
||||
)
|
||||
VIDEO_ENCODER_INFO_KEYS: frozenset[str] = frozenset(
|
||||
f"video.{name}" for name in VIDEO_ENCODER_INFO_FIELD_NAMES
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
"""Video encoder configuration.
|
||||
|
||||
Attributes:
|
||||
vcodec: Video encoder name. ``"auto"`` is resolved during
|
||||
construction (HW encoder if available, else ``libsvtav1``).
|
||||
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
|
||||
g: GOP size (keyframe interval).
|
||||
crf: Quality level — mapped to the native quality parameter of the
|
||||
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
|
||||
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
|
||||
preset: Speed/quality preset. Accepted type is per-codec.
|
||||
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
|
||||
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
|
||||
set ``tune=fastdecode``. Ignored for other codecs.
|
||||
video_backend: Python to be used for encoding. Only ``"pyav"``
|
||||
is currently supported.
|
||||
extra_options: Free-form dictionary of additional video encoder options
|
||||
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
|
||||
"""
|
||||
|
||||
vcodec: str = "libsvtav1" # TODO(CarolinePascal): rename to codec ?
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int | None = 2
|
||||
crf: int | float | None = 30
|
||||
preset: int | str | None = None
|
||||
fast_decode: int = 0
|
||||
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
|
||||
# two backends (encoding and decoding).
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
if self.preset is None and self.vcodec == "libsvtav1":
|
||||
self.preset = LIBSVTAV1_DEFAULT_PRESET
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
|
||||
for src_key, dst_field in (("video.codec", "vcodec"), ("video.pix_fmt", "pix_fmt")):
|
||||
value = video_info.get(src_key)
|
||||
if value is not None:
|
||||
kwargs[dst_field] = value
|
||||
|
||||
for field_name in VIDEO_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{field_name}")
|
||||
if value is None:
|
||||
continue
|
||||
# Persisted as ``{}`` after merges with disagreeing sources — treat as default.
|
||||
if field_name == "extra_options" and not value:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
|
||||
Args:
|
||||
encoders: List of encoder names to detect. If a string, it is converted to a list.
|
||||
Returns:
|
||||
List of available encoder names. If the video backend is not "pyav", returns an empty list.
|
||||
"""
|
||||
if self.video_backend == "pyav":
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import detect_available_encoders_pyav
|
||||
|
||||
return detect_available_encoders_pyav(encoders)
|
||||
return []
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the video encoder configuration."""
|
||||
if self.video_backend == "pyav":
|
||||
require_package("av", extra="dataset")
|
||||
from lerobot.datasets import check_video_encoder_parameters_pyav
|
||||
|
||||
check_video_encoder_parameters_pyav(self.vcodec, self.pix_fmt, self.get_codec_options())
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Check ``vcodec`` and, when it is ``"auto"``, pick a concrete encoder.
|
||||
|
||||
For ``"auto"``, the first hardware encoder in the preference list that is available is chosen; if none are available, ``libsvtav1`` is used. If the
|
||||
resolved codec (explicit or after auto-selection) is not available, raises ``ValueError``.
|
||||
|
||||
Stream-derived canonical codec names listed in :data:`VIDEO_CODECS_ALIASES` are
|
||||
rewritten to their corresponding encoder name (e.g. ``"av1"`` → ``"libsvtav1"``).
|
||||
"""
|
||||
self.vcodec = VIDEO_CODECS_ALIASES.get(self.vcodec, self.vcodec)
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
available = self.detect_available_encoders(HW_VIDEO_CODECS)
|
||||
for encoder in HW_VIDEO_CODECS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
self.vcodec = encoder
|
||||
return
|
||||
logger.warning("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.detect_available_encoders(self.vcodec):
|
||||
logger.info(f"Using video codec: {self.vcodec}")
|
||||
return
|
||||
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
|
||||
|
||||
def get_codec_options(
|
||||
self, encoder_threads: int | None = None, as_strings: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""Translate the tuning fields to codec-specific options.
|
||||
|
||||
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
|
||||
|
||||
Args:
|
||||
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
|
||||
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
|
||||
For h264/hevc, this is mapped to ``threads``.
|
||||
Hardware encoders ignore this parameter.
|
||||
as_strings: If ``True``, casts values to strings.
|
||||
"""
|
||||
opts: dict[str, Any] = {}
|
||||
|
||||
def set_if(key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
opts[key] = value if not as_strings else str(value)
|
||||
|
||||
# GOP size is not a codec-specific option, so it is always set.
|
||||
set_if("g", self.g)
|
||||
|
||||
if self.vcodec == "libsvtav1":
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
svtav1_parts: list[str] = []
|
||||
if self.fast_decode is not None:
|
||||
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
|
||||
if encoder_threads is not None:
|
||||
svtav1_parts.append(f"lp={encoder_threads}")
|
||||
if svtav1_parts:
|
||||
opts["svtav1-params"] = ":".join(svtav1_parts)
|
||||
elif self.vcodec in ("h264", "hevc"):
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
if self.fast_decode:
|
||||
opts["tune"] = "fastdecode"
|
||||
set_if("threads", encoder_threads)
|
||||
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
if self.crf is not None:
|
||||
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
|
||||
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
opts["rc"] = 0
|
||||
set_if("qp", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "h264_vaapi":
|
||||
set_if("qp", self.crf)
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
|
||||
# Extra options are merged last but never override structured fields (values are kept as given).
|
||||
for k, v in self.extra_options.items():
|
||||
if k not in opts:
|
||||
set_if(k, v)
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
@@ -31,15 +31,25 @@ from .dataset_tools import (
|
||||
modify_features,
|
||||
modify_tasks,
|
||||
recompute_stats,
|
||||
reencode_dataset,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
@@ -53,12 +63,19 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"check_video_encoder_parameters_pyav",
|
||||
"detect_available_encoders_pyav",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
@@ -66,6 +83,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
@@ -74,6 +92,7 @@ __all__ = [
|
||||
"modify_features",
|
||||
"modify_tasks",
|
||||
"recompute_stats",
|
||||
"reencode_dataset",
|
||||
"remove_feature",
|
||||
"resolve_delta_timestamps",
|
||||
"safe_stop_image_writer",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -23,9 +24,11 @@ import datasets
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
|
||||
from .compute_stats import aggregate_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .feature_utils import get_hf_features_from_features
|
||||
from .feature_utils import features_equal_for_merge, get_hf_features_from_features
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
@@ -46,11 +49,54 @@ from .utils import (
|
||||
from .video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
def merge_video_feature_info_for_aggregate(all_metadata: list[LeRobotDatasetMetadata]) -> dict[str, dict]:
|
||||
"""Create a merged video feature info dictionary for aggregation. The video encoder info is merged field-by-field: each key is kept only when every source agrees; otherwise that key is set to ``null`` (or ``{}`` for ``video.extra_options``) and a warning is logged.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to merge.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of merged video feature info.
|
||||
"""
|
||||
merged_info = copy.deepcopy(all_metadata[0].features)
|
||||
video_keys = [k for k in merged_info if merged_info[k].get("dtype") == "video"]
|
||||
|
||||
for vk in video_keys:
|
||||
video_infos = [m.features.get(vk, {}).get("info") or {} for m in all_metadata]
|
||||
base_video_info = video_infos[0]
|
||||
|
||||
merged_encoder_info: dict = {}
|
||||
fallback_keys: list[str] = []
|
||||
for info_key in VIDEO_ENCODER_INFO_KEYS:
|
||||
values = [info.get(info_key, None) for info in video_infos]
|
||||
first_value = values[0]
|
||||
all_match = all(v == first_value for v in values[1:])
|
||||
|
||||
if all_match:
|
||||
merged_encoder_info[info_key] = first_value
|
||||
else:
|
||||
fallback_keys.append(info_key)
|
||||
merged_encoder_info[info_key] = {} if info_key == "video.extra_options" else None
|
||||
|
||||
if fallback_keys:
|
||||
logging.warning(
|
||||
f"Merging heterogeneous or incomplete video encoder metadata for feature {vk}. "
|
||||
f"Setting these keys to null: {fallback_keys}.",
|
||||
)
|
||||
|
||||
merged_info[vk]["info"] = {**base_video_info, **merged_encoder_info}
|
||||
# TODO(CarolinePascal): make this variable once we have support for other video backends.
|
||||
merged_info[vk]["info"]["video.video_backend"] = "pyav"
|
||||
|
||||
return merged_info
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
"""Validates that all dataset metadata have consistent properties.
|
||||
|
||||
Ensures all datasets have the same fps, robot_type, and features to guarantee
|
||||
compatibility when aggregating them into a single dataset.
|
||||
Video encoder info is not considered for validation but is merged during aggregation in ``merge_video_feature_info_for_aggregate``.
|
||||
|
||||
Args:
|
||||
all_metadata: List of LeRobotDatasetMetadata objects to validate.
|
||||
@@ -74,7 +120,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
if not features_equal_for_merge(features, meta.features):
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
@@ -274,7 +320,8 @@ def aggregate_datasets(
|
||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
fps, robot_type, _ = validate_all_metadata(all_metadata)
|
||||
features = merge_video_feature_info_for_aggregate(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
@@ -332,7 +379,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -414,9 +460,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
# TODO(CarolinePascal): Move the check before the loop to avoid failing in the middle + add possibility to re-encode the video if the check fails
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
compatibility_check=True,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -23,6 +24,7 @@ import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||
from lerobot.utils.feature_utils import _validate_feature_names
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
@@ -34,12 +36,12 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
check_version_compatibility,
|
||||
@@ -175,7 +177,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -189,6 +190,29 @@ class LeRobotDatasetMetadata:
|
||||
if self.episodes is None:
|
||||
self._load_metadata()
|
||||
|
||||
def filter_episodes(
|
||||
self,
|
||||
predicate: Callable[[dict], bool],
|
||||
candidates: list[int] | None = None,
|
||||
) -> list[int]:
|
||||
"""Filter episodes whose metadata satisfies a given predicate.
|
||||
|
||||
Args:
|
||||
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||
candidates: Optional list of episode indices to restrict evaluation to.
|
||||
|
||||
Returns:
|
||||
List of sorted episode indices that satisfy the predicate.
|
||||
"""
|
||||
self.ensure_readable()
|
||||
if candidates is not None:
|
||||
candidate_set = set(candidates)
|
||||
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
|
||||
else:
|
||||
combined = predicate
|
||||
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
|
||||
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||
|
||||
def _pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
@@ -318,6 +342,49 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@tools.setter
|
||||
def tools(self, value: list[dict] | None) -> None:
|
||||
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
|
||||
|
||||
Writes ``value`` into the on-disk ``info.json`` (or clears the
|
||||
``tools`` key when ``value`` is ``None`` or empty), then reloads
|
||||
``self.info`` so the in-memory metadata matches what's on disk.
|
||||
Saves callers from hand-editing ``info.json`` and re-instantiating
|
||||
the metadata object.
|
||||
"""
|
||||
self.info.tools = [dict(t) for t in value] if value else None
|
||||
write_info(self.info, self.root)
|
||||
self.info = load_info(self.root)
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -510,10 +577,23 @@ class LeRobotDatasetMetadata:
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
@@ -522,7 +602,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
self.info.features[key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -633,7 +713,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -26,7 +26,7 @@ This module provides utilities for:
|
||||
import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -36,6 +36,7 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
|
||||
@@ -60,9 +61,14 @@ from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
VIDEO_DIR,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import encode_video_frames, get_video_info
|
||||
from .video_utils import (
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
reencode_video,
|
||||
)
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -95,6 +101,11 @@ def delete_episodes(
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
Video segments that need re-encoding (because the source file mixes kept and
|
||||
deleted episodes) are re-encoded with the source dataset's existing encoder
|
||||
settings — read back from ``meta/info.json`` — so the output dataset stays
|
||||
consistent with its own metadata.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_indices: List of episode indices to delete.
|
||||
@@ -157,6 +168,11 @@ def split_dataset(
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
Video segments that need re-encoding (because the source file mixes episodes
|
||||
that fall into different splits) are re-encoded with the source dataset's
|
||||
existing encoder settings — read back from ``meta/info.json`` — so each
|
||||
output split stays consistent with its own metadata.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to split.
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
@@ -578,8 +594,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -593,8 +608,7 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
camera_encoder: Video encoder settings used to re-encode the kept frames.
|
||||
"""
|
||||
from fractions import Fraction
|
||||
|
||||
@@ -619,12 +633,13 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
v_out = out.add_stream(vcodec, rate=fps_fraction)
|
||||
codec_options = camera_encoder.get_codec_options(as_strings=True)
|
||||
v_out = out.add_stream(camera_encoder.vcodec, rate=fps_fraction, options=codec_options)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = pix_fmt
|
||||
v_out.pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -687,14 +702,14 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> dict[int, dict]:
|
||||
"""Copy and filter video files, only re-encoding files with deleted episodes.
|
||||
|
||||
For video files that only contain kept episodes, we copy them directly.
|
||||
For files with mixed kept/deleted episodes, we use PyAV filters to efficiently
|
||||
re-encode only the desired segments.
|
||||
re-encode only the desired segments. The encoder used for re-encoding is
|
||||
derived per video key from the source dataset's ``meta/info.json`` so the
|
||||
destination metadata keeps describing the videos accurately.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset to copy from
|
||||
@@ -711,6 +726,9 @@ def _copy_and_reindex_videos(
|
||||
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
logging.info(f"Processing videos for {video_key}")
|
||||
camera_encoder = VideoEncoderConfig.from_video_info(
|
||||
src_dataset.meta.info.features.get(video_key, {}).get("info")
|
||||
)
|
||||
|
||||
if dst_meta.video_path is None:
|
||||
raise ValueError("Destination metadata has no video_path defined")
|
||||
@@ -792,8 +810,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
camera_encoder,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1264,11 +1281,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1282,11 +1295,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
camera_encoder: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1322,11 +1331,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1644,11 +1649,7 @@ def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1663,11 +1664,8 @@ def convert_image_to_video_dataset(
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
camera_encoder: Video encoder settings
|
||||
(``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
@@ -1676,6 +1674,9 @@ def convert_image_to_video_dataset(
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
@@ -1699,7 +1700,10 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder.vcodec}, pixel format: {camera_encoder.pix_fmt}, "
|
||||
f"GOP: {camera_encoder.g}, CRF: {camera_encoder.crf}"
|
||||
)
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1769,11 +1773,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1815,11 +1815,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder=camera_encoder,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1865,7 +1861,9 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder=camera_encoder
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
@@ -1888,3 +1886,83 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _reencode_video_worker(args: tuple) -> Path:
|
||||
"""Picklable worker for :func:`reencode_dataset`'s process pool."""
|
||||
video_path, camera_encoder, encoder_threads = args
|
||||
reencode_video(
|
||||
input_video_path=video_path,
|
||||
output_video_path=video_path,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
return video_path
|
||||
|
||||
|
||||
def reencode_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
camera_encoder: VideoEncoderConfig,
|
||||
encoder_threads: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Re-encode every video in a dataset with a new set of encoding parameters.
|
||||
|
||||
Videos are re-encoded in-place and the video information in ``info.json`` is refreshed.
|
||||
|
||||
Args:
|
||||
dataset: An existing :class:`LeRobotDataset` whose videos will be
|
||||
re-encoded.
|
||||
camera_encoder: Target encoder configuration applied to every video
|
||||
file.
|
||||
encoder_threads: Per-encoder thread count forwarded to
|
||||
:func:`reencode_video`. ``None`` lets the codec decide.
|
||||
num_workers: Number of parallel processes. ``None`` or ``0`` means
|
||||
sequential (no multiprocessing); ``1+`` spawns a
|
||||
:class:`~concurrent.futures.ProcessPoolExecutor`.
|
||||
|
||||
Returns:
|
||||
The same :class:`LeRobotDataset` instance with its metadata updated
|
||||
on disk.
|
||||
"""
|
||||
meta = dataset.meta
|
||||
video_paths_list = []
|
||||
|
||||
# Only re-encode if the videos are not already encoded with the given video encoding parameters
|
||||
for video_key in meta.video_keys:
|
||||
current_info = meta.info.features[video_key].get("info", {})
|
||||
current_encoder = VideoEncoderConfig.from_video_info(current_info)
|
||||
if current_encoder != camera_encoder:
|
||||
video_paths_list.extend((meta.root / VIDEO_DIR / video_key).rglob("*.mp4"))
|
||||
else:
|
||||
logging.info(f"{video_key} videos are already encoded with {camera_encoder}. Nothing to do.")
|
||||
|
||||
if len(video_paths_list) == 0:
|
||||
logging.warning("Dataset has no videos to re-encode.")
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {len(video_paths_list)} video file(s) with {camera_encoder}")
|
||||
|
||||
worker_args = [(vp, camera_encoder, encoder_threads) for vp in video_paths_list]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
for future in tqdm(
|
||||
as_completed(futures),
|
||||
total=len(futures),
|
||||
desc="Re-encoding videos",
|
||||
):
|
||||
future.result()
|
||||
else:
|
||||
for args in tqdm(worker_args, desc="Re-encoding videos"):
|
||||
_reencode_video_worker(args)
|
||||
|
||||
# Refresh video info in metadata for every video key.
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(0, vid_key)
|
||||
meta.info.features[vid_key]["info"] = get_video_info(video_path, camera_encoder=camera_encoder)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -31,6 +31,8 @@ import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults
|
||||
|
||||
from .compute_stats import compute_episode_stats
|
||||
from .dataset_metadata import LeRobotDatasetMetadata
|
||||
from .feature_utils import (
|
||||
@@ -65,14 +67,19 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
@@ -89,20 +96,22 @@ class DatasetWriter:
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
vcodec: str,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
``None`` uses :func:`~lerobot.configs.camera_encoder_defaults`.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
@@ -111,7 +120,7 @@ class DatasetWriter:
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._vcodec = vcodec
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -241,7 +250,14 @@ class DatasetWriter:
|
||||
for key, ft in self._meta.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
stacked_values = np.stack(episode_buffer[key])
|
||||
|
||||
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||
stacked_values = stacked_values.reshape(episode_length)
|
||||
|
||||
episode_buffer[key] = stacked_values
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
@@ -284,7 +300,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._vcodec,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -495,7 +511,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key)
|
||||
self._meta.update_video_info(video_key, camera_encoder=self._camera_encoder)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -564,7 +580,12 @@ class DatasetWriter:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
|
||||
@@ -13,15 +13,23 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.configs import VIDEO_ENCODER_INFO_KEYS
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -46,7 +54,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -108,6 +122,41 @@ def create_empty_dataset_info(
|
||||
)
|
||||
|
||||
|
||||
def features_equal_for_merge(features_a: dict[str, dict], features_b: dict[str, dict]) -> bool:
|
||||
"""Return whether two LeRobotDatasetMetadata ``features`` dicts are compatible for aggregation.
|
||||
|
||||
For video features, keys under ``info`` related to video encoding parameters are ignored during
|
||||
comparison as they do not prevent aggregation.
|
||||
"""
|
||||
|
||||
def _without_encoder_info_keys(feature: dict) -> dict:
|
||||
filtered = dict(feature)
|
||||
filtered_info = filtered.get("info")
|
||||
if isinstance(filtered_info, dict):
|
||||
filtered["info"] = {
|
||||
info_key: info_value
|
||||
for info_key, info_value in filtered_info.items()
|
||||
if info_key not in VIDEO_ENCODER_INFO_KEYS
|
||||
}
|
||||
return filtered
|
||||
|
||||
if set(features_a) != set(features_b):
|
||||
return False
|
||||
for key in features_a:
|
||||
fa_key = features_a[key]
|
||||
fb_key = features_b[key]
|
||||
if fa_key.get("dtype") != fb_key.get("dtype"):
|
||||
return False
|
||||
if fa_key.get("dtype") != "video":
|
||||
if fa_key != fb_key:
|
||||
return False
|
||||
continue
|
||||
|
||||
if _without_encoder_info_keys(fa_key) != _without_encoder_info_keys(fb_key):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
@@ -242,6 +291,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return validate_feature_language(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
@@ -321,6 +372,30 @@ def validate_feature_string(name: str, value: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def validate_feature_language(name: str, value) -> str:
|
||||
"""Validate a feature that is expected to hold language annotations.
|
||||
|
||||
Language columns (``language_persistent`` / ``language_events``) are
|
||||
populated after recording by the annotation pipeline, not at record time.
|
||||
Any value supplied here is dropped before the frame is written, so a
|
||||
non-empty value almost certainly signals a mistake. We warn rather than
|
||||
fail to keep recording resilient.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
str: Always an empty string — language values are non-fatal.
|
||||
"""
|
||||
if value is not None:
|
||||
logging.warning(
|
||||
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
|
||||
f"not at record time. The provided value will be dropped."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
|
||||
@@ -31,10 +31,10 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -186,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -265,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -304,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
|
||||
uses for frame data.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float32(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float32"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
@@ -0,0 +1,545 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
|
||||
is also a float32 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz). This does mean two persistent
|
||||
rows of the same selector emitted within 0.1 s of each other cannot be
|
||||
told apart by ``emitted_at`` — acceptable because persistent annotations
|
||||
(subtask / plan / memory transitions) change on a human-action timescale,
|
||||
not at the camera frame rate."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
return float(unwrap_scalar(row["timestamp"]))
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -24,6 +24,7 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig
|
||||
from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
@@ -36,8 +37,7 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
get_safe_default_video_backend,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
episode_filter: Callable[[dict], bool] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
@@ -58,10 +59,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||
Defaults to None.
|
||||
image_transforms (Callable | None, optional):
|
||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||
conversion. This works for both image-backed and video-backed observations and can later be
|
||||
@@ -177,16 +183,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
camera_encoder (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults`
|
||||
is used by the writer.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
|
||||
Note:
|
||||
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
|
||||
@@ -199,13 +204,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.reader = None
|
||||
self.set_image_transforms(image_transforms)
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -218,6 +221,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
|
||||
if episodes is not None and any(
|
||||
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||
):
|
||||
logger.warning(
|
||||
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||
)
|
||||
|
||||
if episode_filter is not None:
|
||||
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||
if not resolved:
|
||||
raise ValueError(
|
||||
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||
)
|
||||
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||
episodes = resolved
|
||||
self.episodes = episodes
|
||||
|
||||
# Create reader (hf_dataset loaded below)
|
||||
self.reader = DatasetReader(
|
||||
meta=self.meta,
|
||||
@@ -251,12 +271,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
|
||||
self.meta.fps,
|
||||
camera_encoder,
|
||||
encoder_queue_maxsize,
|
||||
encoder_threads,
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
vcodec=self._vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -298,17 +321,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@staticmethod
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
camera_encoder: VideoEncoderConfig | None,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
camera_encoder=camera_encoder,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
@@ -505,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool = False,
|
||||
private: bool | None = None,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
@@ -524,7 +543,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tag_version: If ``True``, create a Git tag for the current codebase
|
||||
version.
|
||||
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
||||
private: If ``True``, create a private repository.
|
||||
private: If ``True``, create a private repository. If ``None``
|
||||
(default), defer to the org default on the Hub (only affects orgs).
|
||||
allow_patterns: Glob pattern(s) restricting which files to upload.
|
||||
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
||||
of ``upload_folder`` for very large datasets.
|
||||
@@ -625,7 +645,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -656,20 +676,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend (used when reading back).
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
|
||||
``'h264'``, ``'hevc'``, ``'auto'``.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
streaming_encoding: If ``True``, encode video frames in real-time
|
||||
during capture instead of writing images first.
|
||||
encoder_queue_maxsize: Max buffered frames per camera when using
|
||||
streaming encoding.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` in write mode.
|
||||
"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -690,23 +710,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -729,12 +749,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Resume recording on an existing dataset.
|
||||
|
||||
@@ -757,13 +777,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend for reading back data.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
vcodec: Video codec for encoding.
|
||||
camera_encoder: Video encoder settings for cameras (codec, quality, etc.).
|
||||
When ``None``, :func:`~lerobot.configs.video.camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
image_writer_threads: Threads for async image writing.
|
||||
streaming_encoding: If ``True``, encode video in real-time during
|
||||
capture.
|
||||
encoder_queue_maxsize: Max buffered frames per camera for streaming.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
@@ -774,7 +796,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root)
|
||||
@@ -783,11 +804,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -796,21 +815,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer for appending
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps, camera_encoder, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder=camera_encoder,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
|
||||
|
||||
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
|
||||
Checks degrade to a no-op when the target codec isn't available locally.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import av
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
try:
|
||||
return av.codec.Codec(vcodec, "w")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return {}
|
||||
return {opt.name: opt for opt in codec.descriptor.options}
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
|
||||
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return ()
|
||||
return tuple(fmt.name for fmt in (codec.video_formats or []))
|
||||
|
||||
|
||||
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
|
||||
|
||||
Each name is probed directly via :func:`get_codec`; input order is preserved.
|
||||
"""
|
||||
if isinstance(encoders, str):
|
||||
encoders = [encoders]
|
||||
|
||||
available: list[str] = []
|
||||
for name in encoders:
|
||||
codec = get_codec(name)
|
||||
if codec is not None and codec.type == "video":
|
||||
available.append(name)
|
||||
else:
|
||||
logger.debug("encoder '%s' not available as video encoder", name)
|
||||
return available
|
||||
|
||||
|
||||
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
|
||||
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
|
||||
type_name = opt.type.name
|
||||
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
num_val = float(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
|
||||
# Check integer type compatibility
|
||||
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
|
||||
raise ValueError(
|
||||
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
|
||||
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
|
||||
)
|
||||
|
||||
# Check numeric range compatibility
|
||||
lo, hi = float(opt.min), float(opt.max)
|
||||
if lo < hi and not (lo <= num_val <= hi):
|
||||
raise ValueError(
|
||||
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
|
||||
)
|
||||
|
||||
elif type_name == "STRING":
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
|
||||
if isinstance(value, str):
|
||||
str_val = value
|
||||
elif isinstance(value, (int, float)):
|
||||
str_val = str(value)
|
||||
else:
|
||||
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
|
||||
|
||||
# Check string choice compatibility
|
||||
choices = [c.name for c in (opt.choices or [])]
|
||||
if choices and str_val not in choices:
|
||||
raise ValueError(
|
||||
f"{label}={str_val!r} is not a supported choice for codec "
|
||||
f"{vcodec!r}; valid choices: {choices}"
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
formats = _get_codec_video_formats(vcodec)
|
||||
if formats and pix_fmt not in formats:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
|
||||
f"supported pixel formats: {list(formats)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
for key, value in codec_options.items():
|
||||
# GOP size is not a codec-specific option, it has to be validated separately.
|
||||
if key == "g":
|
||||
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
|
||||
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
|
||||
continue
|
||||
if key not in supported_options:
|
||||
continue
|
||||
_check_option_value(vcodec, key, value, supported_options[key])
|
||||
|
||||
|
||||
def check_video_encoder_parameters_pyav(vcodec: str, pix_fmt: str, codec_options: dict[str, Any]) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.configs.video.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
ValueError: on the first incompatibility encountered.
|
||||
"""
|
||||
options = _get_codec_options_by_name(vcodec)
|
||||
if not options:
|
||||
raise ValueError(f"Codec {vcodec!r} is not available in the bundled FFmpeg build")
|
||||
_check_pixel_format(vcodec, pix_fmt)
|
||||
_check_codec_options(vcodec, codec_options)
|
||||
@@ -88,7 +88,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -130,6 +129,9 @@ class DatasetInfo:
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
@@ -151,11 +153,15 @@ class DatasetInfo:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
||||
+308
-222
@@ -17,12 +17,14 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@@ -33,90 +35,17 @@ import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.configs import (
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
@@ -132,7 +61,9 @@ def decode_video_frames(
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||
|
||||
@@ -142,88 +73,90 @@ def decode_video_frames(
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
|
||||
)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
def decode_video_frames_pyav(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
return_uint8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
"video_reader" requires installing torchvision from source, see:
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
|
||||
x86_64 and linux armv7l — see the torchcodec block in pyproject.toml for the full matrix).
|
||||
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
|
||||
accurate seek.
|
||||
|
||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
||||
For more info on video decoding, see `benchmark/video/README.md`
|
||||
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
|
||||
forward until we have covered the requested timestamp range. The number of key frames in a
|
||||
video can be adjusted at encoding time to trade off decoding speed against file size.
|
||||
|
||||
See torchvision doc for more info on these two backends:
|
||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
timestamps: List of timestamps (in seconds) to extract frames for.
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||
[0, 1] range.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
Returns:
|
||||
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||
"""
|
||||
video_path = str(video_path)
|
||||
|
||||
# set backend
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesn't support accurate seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
reader = torchvision.io.VideoReader(video_path, "video")
|
||||
video_path = str(video_path)
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
|
||||
# load all frames until last requested frame
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
|
||||
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
|
||||
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
|
||||
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
|
||||
with av.open(video_path) as container:
|
||||
stream = container.streams.video[0]
|
||||
container.seek(int(first_ts * av.time_base), backward=True)
|
||||
|
||||
if backend == "pyav":
|
||||
reader.container.close()
|
||||
for frame in container.decode(stream):
|
||||
if frame.pts is None:
|
||||
continue
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
reader = None
|
||||
if not loaded_frames:
|
||||
raise FrameTimestampError(
|
||||
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
|
||||
)
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
loaded_ts_t = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and timestamps of all loaded frames
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
@@ -234,14 +167,14 @@ def decode_video_frames_torchvision(
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts_t}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
f"\nbackend: pyav"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
closest_ts = loaded_ts_t[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"{closest_ts=}")
|
||||
@@ -260,15 +193,70 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
||||
DEFAULT_DECODER_CACHE_SIZE = 100
|
||||
"""Default LRU capacity for :class:`VideoDecoderCache`.
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
||||
Sized to comfortably hold a small rolling window of episodes worth of decoders
|
||||
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
|
||||
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
|
||||
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
|
||||
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
|
||||
to the constructor (``None`` restores the legacy unbounded behaviour).
|
||||
"""
|
||||
|
||||
|
||||
def _default_max_cache_size() -> int | None:
|
||||
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
|
||||
if raw is None:
|
||||
return DEFAULT_DECODER_CACHE_SIZE
|
||||
raw = raw.strip().lower()
|
||||
if raw in ("", "none", "unbounded", "-1"):
|
||||
return None
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
|
||||
) from e
|
||||
if value <= 0:
|
||||
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
|
||||
return value
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
|
||||
|
||||
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
|
||||
backing it. When the cache is full and a new path is requested, the
|
||||
least-recently-used entry is evicted and its file handle is closed. This
|
||||
bounds host-RAM growth when iterating over datasets with many distinct
|
||||
video files (otherwise each ``DataLoader`` worker pins every decoder it has
|
||||
ever opened until the process exits).
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of decoders to retain. ``None`` disables
|
||||
eviction and restores legacy unbounded behaviour. Defaults to the
|
||||
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
|
||||
:data:`DEFAULT_DECODER_CACHE_SIZE`.
|
||||
"""
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
raise ValueError(f"max_size must be positive or None; got {max_size}")
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
return str(video_path) in self._cache
|
||||
|
||||
def get_decoder(self, video_path: str):
|
||||
"""Get a cached decoder or create a new one."""
|
||||
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
@@ -280,22 +268,36 @@ class VideoDecoderCache:
|
||||
video_path = str(video_path)
|
||||
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
return self._cache[video_path][0]
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
# Evict LRU entries until we are back under the cap. We close
|
||||
# evicted file handles immediately; the associated ``VideoDecoder``
|
||||
# is released to the GC when its last reference goes away.
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
return decoder
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache and close file handles."""
|
||||
"""Clear the cache and close all file handles."""
|
||||
with self._lock:
|
||||
for _, file_handle in self._cache.values():
|
||||
file_handle.close()
|
||||
with contextlib.suppress(Exception):
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
@@ -404,18 +406,17 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if camera_encoder is None:
|
||||
camera_encoder = camera_encoder_defaults()
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -426,42 +427,18 @@ def encode_video_frames(
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -497,8 +474,97 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def reencode_video(
|
||||
input_video_path: Path | str,
|
||||
output_video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""Re-encode a video file using the given encoder configuration.
|
||||
|
||||
Args:
|
||||
input_video_path: Existing video file to read.
|
||||
output_video_path: Path for the re-encoded file.
|
||||
camera_encoder: Encoder configuration. Defaults to :func:`camera_encoder_defaults`.
|
||||
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
|
||||
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
|
||||
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
|
||||
"""
|
||||
|
||||
camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping re-encode.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_options = camera_encoder.get_codec_options(encoder_threads, as_strings=True)
|
||||
vcodec = camera_encoder.vcodec
|
||||
pix_fmt = camera_encoder.pix_fmt
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
if log_level is not None:
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
|
||||
try:
|
||||
with av.open(input_video_path, mode="r") as src:
|
||||
try:
|
||||
in_stream = src.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise ValueError(f"No video stream in {input_video_path}") from e
|
||||
|
||||
fps = (
|
||||
in_stream.base_rate
|
||||
) # We allow fractional fps though LeRobotDataset only supports integer fps
|
||||
width = int(in_stream.width)
|
||||
height = int(in_stream.height)
|
||||
|
||||
with av.open(
|
||||
tmp_output_video_path,
|
||||
mode="w",
|
||||
options={
|
||||
"movflags": "faststart"
|
||||
}, # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
) as dst:
|
||||
out_stream = dst.add_stream(vcodec, fps, options=video_options)
|
||||
out_stream.pix_fmt = pix_fmt
|
||||
out_stream.width = width
|
||||
out_stream.height = height
|
||||
|
||||
for frame in src.decode(in_stream):
|
||||
frame = frame.reformat(width=width, height=height, format=pix_fmt)
|
||||
packet = out_stream.encode(frame)
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
packet = out_stream.encode()
|
||||
if packet:
|
||||
dst.mux(packet)
|
||||
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
except Exception:
|
||||
Path(tmp_output_video_path).unlink(missing_ok=True)
|
||||
raise
|
||||
finally:
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not output_video_path.exists():
|
||||
raise OSError(f"Video re-encoding did not work. File not found: {output_video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
overwrite: bool = True,
|
||||
compatibility_check: bool = False,
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
@@ -511,6 +577,7 @@ def concatenate_video_files(
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
compatibility_check: Whether to check if the input videos are compatible. Default is False.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
@@ -529,6 +596,22 @@ def concatenate_video_files(
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# This check may be skipped at recording time as videos are encoded with the same encoder config.
|
||||
if compatibility_check:
|
||||
reference_video_info = get_video_info(input_video_paths[0])
|
||||
for input_path in input_video_paths[1:]:
|
||||
video_info = get_video_info(input_path)
|
||||
if (
|
||||
video_info["video.height"] != reference_video_info["video.height"]
|
||||
or video_info["video.width"] != reference_video_info["video.width"]
|
||||
or video_info["video.fps"] != reference_video_info["video.fps"]
|
||||
or video_info["video.codec"] != reference_video_info["video.codec"]
|
||||
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
|
||||
)
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
@@ -595,26 +678,20 @@ class _CameraEncoderThread(threading.Thread):
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
codec_options: dict[str, str],
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.codec_options = codec_options
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -650,19 +727,9 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
@@ -728,22 +795,24 @@ class StreamingVideoEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder: Video encoder settings applied to all cameras.
|
||||
When ``None``, :func:`camera_encoder_defaults` is used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
"""
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self._camera_encoder = camera_encoder or camera_encoder_defaults()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -774,18 +843,17 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder.vcodec
|
||||
codec_options = self._camera_encoder.get_codec_options(self._encoder_threads, as_strings=True)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -990,8 +1058,18 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
@@ -1022,6 +1100,14 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder is not None:
|
||||
for field_name, field_value in asdict(camera_encoder).items():
|
||||
# vcodec is already populated from the video stream
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
|
||||
@@ -24,12 +24,7 @@ import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
IsaaclabArenaProcessorStep,
|
||||
LiberoActionProcessorStep,
|
||||
LiberoProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.utils.constants import (
|
||||
@@ -128,7 +123,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], **extra_kwargs)
|
||||
return {self.type: {0: vec}}
|
||||
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
def get_env_processors(self):
|
||||
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
|
||||
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
|
||||
|
||||
@@ -441,13 +436,10 @@ class LiberoEnv(EnvConfig):
|
||||
is_libero_plus=self.is_libero_plus,
|
||||
)
|
||||
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
max_state_dim = getattr(policy_cfg, "max_state_dim", None) if getattr(policy_cfg, "type", None) == "evo1" else None
|
||||
action_feature = self.features.get(ACTION)
|
||||
action_dim = int(action_feature.shape[0]) if action_feature is not None else 7
|
||||
def get_env_processors(self):
|
||||
return (
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep(max_state_dim=max_state_dim)]),
|
||||
PolicyProcessorPipeline(steps=[LiberoActionProcessorStep(action_dim=action_dim)]),
|
||||
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
|
||||
|
||||
@@ -713,7 +705,7 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def get_env_processors(self, policy_cfg: Any | None = None):
|
||||
def get_env_processors(self):
|
||||
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
|
||||
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
|
||||
if not state_keys and not camera_keys:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
@@ -53,14 +52,7 @@ def make_env_pre_post_processors(
|
||||
|
||||
return make_xvla_libero_pre_post_processors()
|
||||
|
||||
get_processors = env_cfg.get_env_processors
|
||||
signature = inspect.signature(get_processors)
|
||||
supports_policy_cfg = "policy_cfg" in signature.parameters or any(
|
||||
param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
|
||||
)
|
||||
if supports_policy_cfg:
|
||||
return get_processors(policy_cfg=policy_cfg)
|
||||
return get_processors()
|
||||
return env_cfg.get_env_processors()
|
||||
|
||||
|
||||
def make_env(
|
||||
|
||||
@@ -18,12 +18,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
_placo_runtime_error: ImportError | None = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
try:
|
||||
import placo # type: ignore[import-not-found]
|
||||
except ImportError as _placo_import_err:
|
||||
placo = None
|
||||
_placo_runtime_error = _placo_import_err
|
||||
|
||||
|
||||
def _raise_if_placo_unusable() -> None:
|
||||
if placo is None and _placo_runtime_error is not None:
|
||||
raise ImportError(
|
||||
f"placo is installed but failed to import: {_placo_runtime_error!s}"
|
||||
) from _placo_runtime_error
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
@@ -44,6 +57,7 @@ class RobotKinematics:
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
require_package("placo", extra="placo-dep")
|
||||
_raise_if_placo_unusable()
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -43,6 +43,7 @@ from .tables import (
|
||||
CAN_CMD_SET_ZERO,
|
||||
DEFAULT_BAUDRATE,
|
||||
DEFAULT_TIMEOUT_MS,
|
||||
HANDSHAKE_TIMEOUT_S,
|
||||
MODEL_RESOLUTION,
|
||||
MOTOR_LIMIT_PARAMS,
|
||||
NORMALIZED_DATA,
|
||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._is_connected = False
|
||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||
|
||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
||||
def _query_status_via_clear_fault(
|
||||
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||
) -> tuple[bool, can.Message | None]:
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||
|
||||
def _recv_status_via_clear_fault(
|
||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
faulted_motors = []
|
||||
|
||||
for motor_name in self.motors:
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
||||
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||
if msg is None:
|
||||
missing_motors.append(motor_name)
|
||||
elif has_fault:
|
||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
|
||||
return responses
|
||||
|
||||
def _recv_all_messages_until_quiet(
|
||||
self,
|
||||
*,
|
||||
timeout: float = RUNNING_TIMEOUT,
|
||||
max_messages: int = 4096,
|
||||
) -> list[can.Message]:
|
||||
"""
|
||||
Receive frames until the bus goes quiet.
|
||||
|
||||
Args:
|
||||
timeout: Poll timeout used for each recv() call. Collection stops
|
||||
when one recv() times out (quiet gap).
|
||||
max_messages: Safety cap to prevent unbounded loops.
|
||||
"""
|
||||
out: list[can.Message] = []
|
||||
max_messages = max(1, max_messages)
|
||||
timeout = max(0.0, timeout)
|
||||
|
||||
try:
|
||||
while len(out) < max_messages:
|
||||
msg = self._bus().recv(timeout=timeout)
|
||||
if msg is None:
|
||||
break
|
||||
out.append(msg)
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||
|
||||
return out
|
||||
|
||||
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||
"""
|
||||
Decode all received feedback frames and update cached motor states.
|
||||
|
||||
Returns:
|
||||
Set of payload recv_ids that were successfully mapped to motors.
|
||||
"""
|
||||
processed_recv_ids: set[int] = set()
|
||||
for msg in messages:
|
||||
if len(msg.data) < 1:
|
||||
logger.debug(
|
||||
f"Dropping short CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
recv_id = int(msg.data[0])
|
||||
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||
if motor_name is None:
|
||||
logger.debug(
|
||||
f"Unmapped CAN frame on {self.port} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||
)
|
||||
continue
|
||||
|
||||
self._process_response(motor_name, msg)
|
||||
processed_recv_ids.add(recv_id)
|
||||
|
||||
return processed_recv_ids
|
||||
|
||||
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||
"""
|
||||
Drain pending RX frames from the CAN interface.
|
||||
|
||||
This is used by higher-level controllers to drop stale feedback before issuing
|
||||
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||
It should also be called once when a controller instance is created/connected,
|
||||
to clear residual frames left on the interface from previous sessions.
|
||||
"""
|
||||
drained = 0
|
||||
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||
max_messages = max(1, max_messages)
|
||||
try:
|
||||
while drained < max_messages:
|
||||
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||
if msg is None:
|
||||
break
|
||||
drained += 1
|
||||
except (can.CanError, OSError) as e:
|
||||
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||
return drained
|
||||
|
||||
def _speed_control(
|
||||
self,
|
||||
motor: NameOrID,
|
||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
self._bus().send(msg)
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||
# This avoids dropping useful frames when responses from different motors interleave.
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
||||
for recv_id, motor_name in recv_id_to_motor.items():
|
||||
if msg := responses.get(recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||
"""Convert float to unsigned integer for CAN transmission."""
|
||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
try:
|
||||
self._decode_motor_state(msg.data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to decode response from {motor} "
|
||||
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||
)
|
||||
|
||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||
"""Retrieve a specific value from the state cache."""
|
||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
self._bus().send(msg)
|
||||
updated_motors.append(motor)
|
||||
|
||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
||||
|
||||
for response in responses.values():
|
||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
||||
if payload_motor_name is not None:
|
||||
self._process_response(payload_motor_name, response)
|
||||
else:
|
||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
||||
self._decode_motor_state(response.data)
|
||||
messages = self._recv_all_messages_until_quiet()
|
||||
processed_recv_ids = self._process_feedback_messages(messages)
|
||||
|
||||
for motor in updated_motors:
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
if recv_id not in responses:
|
||||
if recv_id not in processed_recv_ids:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
|
||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
RUNNING_TIMEOUT = 0.003
|
||||
HANDSHAKE_TIMEOUT_S = 0.05
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
|
||||
@@ -17,15 +17,15 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -33,22 +33,22 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
||||
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
||||
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"Evo1Config",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"SACConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -28,11 +28,12 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_eo1 import EO1Config
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -44,6 +43,8 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_evo1_README.md
|
||||
@@ -1,225 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("evo1_exact")
|
||||
@dataclass
|
||||
class Evo1SchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step: int) -> float:
|
||||
if current_step < self.num_warmup_steps:
|
||||
return current_step / max(1, self.num_warmup_steps)
|
||||
progress = (current_step - self.num_warmup_steps) / max(
|
||||
1, num_training_steps - self.num_warmup_steps
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("evo1")
|
||||
@dataclass
|
||||
class Evo1Config(PreTrainedConfig):
|
||||
training_stage: str = "stage1"
|
||||
use_amp: bool = True
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
max_state_dim: int = 24
|
||||
max_action_dim: int = 24
|
||||
max_views: int = 3
|
||||
image_resolution: tuple[int, int] = (448, 448)
|
||||
empty_cameras: int = 0
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
|
||||
vlm_num_layers: int | None = 14
|
||||
vlm_dtype: str = "bfloat16"
|
||||
use_flash_attn: bool = True
|
||||
action_head: str = "flowmatching"
|
||||
embed_dim: int = 896
|
||||
hidden_dim: int = 1024
|
||||
state_hidden_dim: int = 1024
|
||||
num_heads: int = 8
|
||||
num_layers: int = 8
|
||||
dropout: float = 0.0
|
||||
num_inference_timesteps: int = 32
|
||||
num_categories: int = 1
|
||||
return_cls_only: bool = False
|
||||
enable_gradient_checkpointing: bool = True
|
||||
gradient_checkpointing_use_reentrant: bool = False
|
||||
|
||||
finetune_vlm: bool | None = None
|
||||
finetune_language_model: bool | None = None
|
||||
finetune_vision_model: bool | None = None
|
||||
finetune_action_head: bool | None = None
|
||||
# Reapply stage defaults after loading checkpoint configs so stage2 cannot
|
||||
# accidentally inherit the frozen VLM flags stored by a stage1 checkpoint.
|
||||
apply_training_stage_defaults: bool = True
|
||||
|
||||
task_field: str = "task"
|
||||
embodiment_id_field: str | None = None
|
||||
default_embodiment_id: int = 0
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 300
|
||||
drop_last: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.training_stage not in {"stage1", "stage2"}:
|
||||
raise ValueError(
|
||||
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
|
||||
)
|
||||
|
||||
if self.apply_training_stage_defaults:
|
||||
if self.training_stage == "stage1":
|
||||
self.finetune_vlm = False
|
||||
self.finetune_language_model = False
|
||||
self.finetune_vision_model = False
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
self.finetune_vlm = True
|
||||
self.finetune_language_model = True
|
||||
self.finetune_vision_model = True
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage1":
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = False
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = False
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
has_explicit_branch_flags = any(
|
||||
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
|
||||
)
|
||||
if not has_explicit_branch_flags:
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = True
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = True
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = True
|
||||
elif self.finetune_vlm is None:
|
||||
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = True
|
||||
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = False
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = False
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = False
|
||||
|
||||
branch_vlm = self.finetune_language_model or self.finetune_vision_model
|
||||
if self.finetune_vlm != branch_vlm:
|
||||
raise ValueError(
|
||||
"Inconsistent EVO1 finetune config: "
|
||||
f"finetune_vlm={self.finetune_vlm} but "
|
||||
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
|
||||
"When branch-level flags are used, finetune_vlm must match their effective union."
|
||||
)
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.input_features is None:
|
||||
self.input_features = {}
|
||||
if self.output_features is None:
|
||||
self.output_features = {}
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution),
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
self.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return Evo1SchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,234 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
|
||||
|
||||
|
||||
def _cfgget(config: Any, key: str, default=None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(key, default)
|
||||
return getattr(config, key, default)
|
||||
|
||||
|
||||
class EVO1(nn.Module):
|
||||
def __init__(self, config: dict):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device = _cfgget(config, "device", "cuda")
|
||||
self.return_cls_only = _cfgget(config, "return_cls_only", False)
|
||||
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
|
||||
image_size = _cfgget(config, "image_size", 448)
|
||||
if image_size is None:
|
||||
image_resolution = _cfgget(config, "image_resolution", (448, 448))
|
||||
image_size = int(image_resolution[0])
|
||||
|
||||
self.embedder = InternVL3Embedder(
|
||||
model_name=vlm_name,
|
||||
image_size=image_size,
|
||||
device=self._device,
|
||||
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
|
||||
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
|
||||
use_flash_attn=_cfgget(config, "use_flash_attn", True),
|
||||
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
|
||||
gradient_checkpointing_use_reentrant=_cfgget(
|
||||
config, "gradient_checkpointing_use_reentrant", False
|
||||
),
|
||||
)
|
||||
|
||||
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
|
||||
if action_head_type != "flowmatching":
|
||||
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
|
||||
|
||||
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
|
||||
per_action_dim = _cfgget(config, "per_action_dim", 7)
|
||||
action_dim = horizon * per_action_dim
|
||||
|
||||
if isinstance(config, dict):
|
||||
config["horizon"] = horizon
|
||||
config["per_action_dim"] = per_action_dim
|
||||
config["action_dim"] = action_dim
|
||||
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = per_action_dim
|
||||
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
|
||||
|
||||
def _normalize_image_batches(
|
||||
self,
|
||||
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
prompt: str | list[str] | None,
|
||||
image_mask: torch.Tensor,
|
||||
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
|
||||
if not images:
|
||||
raise ValueError("EVO1 expects at least one image per sample.")
|
||||
|
||||
first = images[0]
|
||||
if isinstance(first, (Image.Image, torch.Tensor)):
|
||||
image_batches = [list(images)] # type: ignore[arg-type]
|
||||
else:
|
||||
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
|
||||
|
||||
batch_size = len(image_batches)
|
||||
if prompt is None:
|
||||
prompts = [""] * batch_size
|
||||
elif isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
else:
|
||||
prompts = [str(p) for p in prompt]
|
||||
if len(prompts) != batch_size:
|
||||
raise ValueError(
|
||||
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
if image_mask.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
return image_batches, prompts, image_mask
|
||||
|
||||
def get_vl_embeddings(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str | list[str] | None = None,
|
||||
return_cls_only: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
if return_cls_only is None:
|
||||
return_cls_only = self.return_cls_only
|
||||
|
||||
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
|
||||
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
|
||||
image_tensors_batch=image_batches,
|
||||
image_masks=image_mask,
|
||||
text_prompts=prompts,
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
|
||||
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(state_input, list):
|
||||
state_tensor = torch.tensor(state_input)
|
||||
elif isinstance(state_input, torch.Tensor):
|
||||
state_tensor = state_input
|
||||
else:
|
||||
raise TypeError(f"Unsupported state input type: {type(state_input)}")
|
||||
|
||||
if state_tensor.ndim == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
return state_tensor.to(self._device)
|
||||
|
||||
def predict_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.action_head.get_action(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
return self.action_head(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_inference(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str,
|
||||
state_input: list | torch.Tensor,
|
||||
return_cls_only: bool | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
|
||||
fused_tokens = self.get_vl_embeddings(
|
||||
images=[images],
|
||||
image_mask=image_mask,
|
||||
prompt=[prompt],
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
state_tensor = self.prepare_state(state_input)
|
||||
action = self.predict_action(
|
||||
fused_tokens,
|
||||
state_tensor,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
|
||||
action = action.to(torch.float32)
|
||||
return action
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
):
|
||||
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
|
||||
|
||||
def _set_module_trainable(self, module: nn.Module, trainable: bool):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = trainable
|
||||
|
||||
def set_finetune_flags(self):
|
||||
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
|
||||
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
|
||||
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
|
||||
has_explicit_branch_flags = any(
|
||||
flag is not None for flag in (finetune_language_model, finetune_vision_model)
|
||||
)
|
||||
finetune_language_model = bool(finetune_language_model)
|
||||
finetune_vision_model = bool(finetune_vision_model)
|
||||
finetune_vlm = bool(finetune_vlm)
|
||||
|
||||
if has_explicit_branch_flags:
|
||||
self._set_module_trainable(self.embedder, False)
|
||||
if hasattr(self.embedder.model, "language_model"):
|
||||
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
|
||||
if hasattr(self.embedder.model, "vision_model"):
|
||||
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
|
||||
if hasattr(self.embedder.model, "mlp1"):
|
||||
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
|
||||
elif not finetune_vlm:
|
||||
self._set_module_trainable(self.embedder, False)
|
||||
|
||||
if not _cfgget(self.config, "finetune_action_head", False):
|
||||
self._set_module_trainable(self.action_head, False)
|
||||
@@ -1,456 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfgget(config, key: str, default=None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(key, default)
|
||||
return getattr(config, key, default)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int = 1000):
|
||||
super().__init__()
|
||||
pe = torch.zeros(max_len, dim)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, seq_len: int):
|
||||
if seq_len > self.pe.size(1):
|
||||
self._extend_pe(seq_len)
|
||||
return self.pe[:, :seq_len, :]
|
||||
|
||||
def _extend_pe(self, new_max_len):
|
||||
old_max_len, dim = self.pe.size(1), self.pe.size(2)
|
||||
if new_max_len <= old_max_len:
|
||||
return
|
||||
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
|
||||
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
|
||||
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
|
||||
extra_pe = extra_pe.unsqueeze(0)
|
||||
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
|
||||
self.pe = new_pe
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
if num_categories <= 1:
|
||||
self.linear = nn.Linear(in_dim, out_dim)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
|
||||
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||
if self.num_categories <= 1:
|
||||
if x.dtype != self.linear.weight.dtype:
|
||||
x = x.to(dtype=self.linear.weight.dtype)
|
||||
return self.linear(x)
|
||||
|
||||
if x.dtype != self.weight.dtype:
|
||||
x = x.to(dtype=self.weight.dtype)
|
||||
|
||||
orig_shape = x.shape
|
||||
x_flat = x.reshape(-1, orig_shape[-1])
|
||||
if category_id.dim() == 0:
|
||||
cid = category_id.item()
|
||||
out = x_flat @ self.weight[cid] + self.bias[cid]
|
||||
else:
|
||||
category_id = category_id.reshape(-1)
|
||||
if category_id.numel() != x_flat.size(0):
|
||||
raise ValueError(
|
||||
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
|
||||
)
|
||||
weight_selected = self.weight[category_id]
|
||||
bias_selected = self.bias[category_id]
|
||||
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
|
||||
out_shape = orig_shape[:-1] + (out.shape[-1],)
|
||||
return out.view(out_shape)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
|
||||
super().__init__()
|
||||
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
|
||||
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||
out = self.activation(self.fc1(x, category_id))
|
||||
out = self.fc2(out, category_id)
|
||||
return out
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
|
||||
):
|
||||
super().__init__()
|
||||
self.horizon = horizon
|
||||
self.embed_dim = embed_dim
|
||||
self.num_categories = num_categories
|
||||
|
||||
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
|
||||
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
|
||||
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
|
||||
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
|
||||
batch_size, horizon, action_dim = action_seq.shape
|
||||
assert self.horizon == horizon, "Action sequence length must match horizon"
|
||||
|
||||
x = action_seq.reshape(batch_size * horizon, action_dim)
|
||||
if category_id.dim() == 0:
|
||||
cat_ids = category_id.expand(horizon * batch_size)
|
||||
else:
|
||||
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
|
||||
|
||||
out = self.activation(self.W1(x, cat_ids))
|
||||
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
|
||||
out = out.view(batch_size, horizon, -1) + pos_enc
|
||||
out = out.view(batch_size * horizon, -1)
|
||||
out = self.activation(self.W2(out, cat_ids))
|
||||
out = self.W3(out, cat_ids)
|
||||
return out.view(batch_size, horizon, self.embed_dim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
|
||||
|
||||
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
|
||||
x = self.norm1(action_tokens)
|
||||
attn_out, _ = self.attn(x, context_tokens, context_tokens)
|
||||
x = action_tokens + attn_out
|
||||
x2 = self.norm2(x)
|
||||
if time_emb is not None:
|
||||
x2 = x2 + time_emb.unsqueeze(1)
|
||||
ff_out = self.ff(x2)
|
||||
return x + ff_out
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config=None,
|
||||
embed_dim: int = 896,
|
||||
hidden_dim: int = 1024,
|
||||
action_dim: int = 16 * 7,
|
||||
horizon: int = 16,
|
||||
per_action_dim: int = 7,
|
||||
num_heads: int = 8,
|
||||
num_layers: int = 8,
|
||||
dropout: float = 0.0,
|
||||
num_inference_timesteps: int = 20,
|
||||
num_categories: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is not None:
|
||||
embed_dim = _cfgget(config, "embed_dim", embed_dim)
|
||||
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
|
||||
action_dim = _cfgget(config, "action_dim", action_dim)
|
||||
horizon = _cfgget(config, "horizon", horizon)
|
||||
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
|
||||
num_heads = _cfgget(config, "num_heads", num_heads)
|
||||
num_layers = _cfgget(config, "num_layers", num_layers)
|
||||
dropout = _cfgget(config, "dropout", dropout)
|
||||
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
|
||||
num_categories = _cfgget(config, "num_categories", num_categories)
|
||||
self.config = config
|
||||
else:
|
||||
self.config = SimpleNamespace(
|
||||
embed_dim=embed_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
per_action_dim=per_action_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
num_inference_timesteps=num_inference_timesteps,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
|
||||
self.embed_dim = embed_dim
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
|
||||
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
|
||||
|
||||
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
hidden_dim=embed_dim * 4,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(embed_dim)
|
||||
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
|
||||
self.mlp_head = CategorySpecificMLP(
|
||||
input_dim=embed_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=action_dim,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
self.state_encoder = None
|
||||
state_dim = _cfgget(self.config, "state_dim")
|
||||
if state_dim is not None:
|
||||
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
input_dim=state_dim,
|
||||
hidden_dim=state_hidden,
|
||||
output_dim=embed_dim,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
if horizon > 1:
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.per_action_dim,
|
||||
embed_dim=embed_dim,
|
||||
hidden_dim=embed_dim,
|
||||
horizon=horizon,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
self.single_action_proj = None
|
||||
else:
|
||||
self.action_encoder = None
|
||||
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
|
||||
|
||||
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
|
||||
if self.horizon > 1 and self.action_encoder is not None:
|
||||
return self.action_encoder(action_seq, embodiment_id)
|
||||
if self.single_action_proj is None:
|
||||
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
|
||||
return self.single_action_proj(action_seq)
|
||||
|
||||
def _expand_action_mask(
|
||||
self,
|
||||
action_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
per_action_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
if action_mask is None:
|
||||
raise ValueError("action_mask must be provided for flow matching inference.")
|
||||
|
||||
if action_mask.dim() == 2:
|
||||
expected_last_dim = self.horizon * per_action_dim
|
||||
if action_mask.shape == (batch_size, expected_last_dim):
|
||||
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
|
||||
elif action_mask.shape == (batch_size, per_action_dim):
|
||||
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
|
||||
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
|
||||
)
|
||||
elif action_mask.dim() == 3:
|
||||
expected_shape = (batch_size, self.horizon, per_action_dim)
|
||||
if tuple(action_mask.shape) != expected_shape:
|
||||
raise ValueError(
|
||||
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
|
||||
)
|
||||
expanded_mask = action_mask
|
||||
else:
|
||||
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
|
||||
|
||||
return expanded_mask.to(device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
actions_gt: torch.Tensor = None,
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
state_mask: torch.Tensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.get_action(
|
||||
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
|
||||
)
|
||||
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
if embodiment_id is None:
|
||||
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
context_tokens = fused_tokens
|
||||
if state is not None and self.state_encoder is not None:
|
||||
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||
|
||||
t = (
|
||||
torch.distributions.Beta(2, 2)
|
||||
.sample((batch_size,))
|
||||
.clamp(0.02, 0.98)
|
||||
.to(device)
|
||||
.to(dtype=self.dtype)
|
||||
)
|
||||
time_index = (t * 999).long().clamp_(0, 999)
|
||||
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
|
||||
|
||||
actions_gt_seq = actions_gt
|
||||
noise = torch.rand_like(actions_gt) * 2 - 1
|
||||
if action_mask is not None:
|
||||
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
|
||||
if action_mask.shape != noise.shape:
|
||||
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
|
||||
actions_gt_seq = actions_gt_seq * action_mask
|
||||
noise = noise * action_mask
|
||||
|
||||
if self.horizon > 1:
|
||||
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
|
||||
else:
|
||||
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
|
||||
t_broadcast = t.view(batch_size, 1, 1)
|
||||
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
|
||||
|
||||
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
|
||||
target_dtype = self.dtype
|
||||
action_tokens = action_tokens.to(dtype=target_dtype)
|
||||
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
|
||||
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
|
||||
return pred_velocity, noise
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
):
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
if embodiment_id is None:
|
||||
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
context_tokens = fused_tokens
|
||||
if state is not None and self.state_encoder is not None:
|
||||
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||
|
||||
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
|
||||
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
|
||||
|
||||
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
action_mask = self._expand_action_mask(
|
||||
action_mask,
|
||||
batch_size=batch_size,
|
||||
per_action_dim=per_action_dim,
|
||||
device=action_seq.device,
|
||||
dtype=action_seq.dtype,
|
||||
)
|
||||
action_seq = action_seq * action_mask
|
||||
|
||||
target_dtype = self.dtype
|
||||
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||
|
||||
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
|
||||
if num_steps <= 0:
|
||||
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
for i in range(num_steps):
|
||||
t = i / num_steps
|
||||
time_index = min(int(t * 999), 999)
|
||||
time_emb = (
|
||||
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
|
||||
)
|
||||
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
|
||||
pred = self.mlp_head(x_pooled, embodiment_id)
|
||||
action = action + dt * pred
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
return action_seq.reshape(batch_size, -1)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
@@ -1,435 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoTokenizer = None
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
|
||||
IMG_START_TOKEN = "<img>" # nosec B105
|
||||
IMG_END_TOKEN = "</img>" # nosec B105
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _patch_vision_encoder_checkpointing(encoder: nn.Module, use_reentrant: bool) -> None:
|
||||
if getattr(encoder, "_evo1_checkpoint_patch_applied", False):
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
return
|
||||
|
||||
original_forward = encoder.forward
|
||||
|
||||
def forward_with_checkpoint_kwargs(self, *args, **kwargs):
|
||||
original_checkpoint = torch.utils.checkpoint.checkpoint
|
||||
|
||||
def checkpoint(function, *checkpoint_args, **checkpoint_kwargs):
|
||||
checkpoint_kwargs.setdefault("use_reentrant", self.gradient_checkpointing_use_reentrant)
|
||||
return original_checkpoint(function, *checkpoint_args, **checkpoint_kwargs)
|
||||
|
||||
torch.utils.checkpoint.checkpoint = checkpoint
|
||||
try:
|
||||
return original_forward(*args, **kwargs)
|
||||
finally:
|
||||
torch.utils.checkpoint.checkpoint = original_checkpoint
|
||||
|
||||
encoder.gradient_checkpointing_use_reentrant = use_reentrant
|
||||
encoder.forward = types.MethodType(forward_with_checkpoint_kwargs, encoder)
|
||||
encoder._evo1_checkpoint_patch_applied = True
|
||||
|
||||
|
||||
def flash_attn_is_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _internvl_transformers5_load_compatibility():
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
original_linspace = torch.linspace
|
||||
original_mark_tied = PreTrainedModel.mark_tied_weights_as_initialized
|
||||
|
||||
def linspace(*args, **kwargs):
|
||||
if kwargs.get("device") is None:
|
||||
kwargs["device"] = torch.device("cpu")
|
||||
return original_linspace(*args, **kwargs)
|
||||
|
||||
def mark_tied_weights_as_initialized(self, loading_info):
|
||||
if not hasattr(self, "all_tied_weights_keys"):
|
||||
self.all_tied_weights_keys = {}
|
||||
return original_mark_tied(self, loading_info)
|
||||
|
||||
torch.linspace = linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = mark_tied_weights_as_initialized
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.linspace = original_linspace
|
||||
PreTrainedModel.mark_tied_weights_as_initialized = original_mark_tied
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=10000)
|
||||
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
|
||||
aspect_ratio = orig_width / orig_height
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
best_ratio_diff = float("inf")
|
||||
best_ratio = (1, 1)
|
||||
area = orig_width * orig_height
|
||||
for ratio in target_ratios:
|
||||
target_ar = ratio[0] / ratio[1]
|
||||
diff = abs(aspect_ratio - target_ar)
|
||||
if diff < best_ratio_diff:
|
||||
best_ratio_diff = diff
|
||||
best_ratio = ratio
|
||||
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
|
||||
target_width = image_size * ratio_w
|
||||
target_height = image_size * ratio_h
|
||||
blocks = ratio_w * ratio_h
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size,
|
||||
)
|
||||
processed_images.append(resized_img.crop(box))
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
processed_images.append(image.resize((image_size, image_size)))
|
||||
return processed_images
|
||||
|
||||
|
||||
class InternVL3Embedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_name="OpenGVLab/InternVL3-1B",
|
||||
image_size=448,
|
||||
device="cuda",
|
||||
num_language_layers: int | None = 14,
|
||||
model_dtype: str | torch.dtype = "bfloat16",
|
||||
use_flash_attn: bool = True,
|
||||
enable_gradient_checkpointing: bool = True,
|
||||
gradient_checkpointing_use_reentrant: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._requested_device = device
|
||||
self.image_size = image_size
|
||||
self.num_language_layers = num_language_layers
|
||||
self.max_text_length = 1024
|
||||
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
|
||||
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
|
||||
|
||||
require_package("transformers", extra="evo1")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
|
||||
if isinstance(model_dtype, str):
|
||||
try:
|
||||
model_dtype = getattr(torch, model_dtype)
|
||||
except AttributeError as exc:
|
||||
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
|
||||
|
||||
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
|
||||
if use_flash_attn and not resolved_use_flash_attn:
|
||||
logger.warning("flash_attn is not installed. Falling back to standard attention.")
|
||||
|
||||
# InternVL3 remote code predates Transformers 5 post-init conventions:
|
||||
# it computes stochastic-depth scalars via torch.linspace(...).item()
|
||||
# while Transformers initializes under torch.device("meta"), and it
|
||||
# does not populate all_tied_weights_keys before loading finalization.
|
||||
with _internvl_transformers5_load_compatibility():
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=resolved_use_flash_attn,
|
||||
low_cpu_mem_usage=True,
|
||||
_fast_init=False,
|
||||
).to(self._requested_device)
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
layers = self.model.language_model.model.layers
|
||||
else:
|
||||
layers = self.model.language_model.layers
|
||||
if self.num_language_layers is not None:
|
||||
layers = layers[: self.num_language_layers]
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
|
||||
else:
|
||||
self.model.language_model.layers = torch.nn.ModuleList(layers)
|
||||
self.model.language_model.lm_head = torch.nn.Identity()
|
||||
|
||||
self._configure_memory_features()
|
||||
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||
|
||||
def _configure_memory_features(self) -> None:
|
||||
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
|
||||
|
||||
if not self.enable_gradient_checkpointing:
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
self.model.vision_model.encoder.gradient_checkpointing = False
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
if hasattr(language_model, "gradient_checkpointing_disable"):
|
||||
language_model.gradient_checkpointing_disable()
|
||||
elif hasattr(language_model, "gradient_checkpointing"):
|
||||
language_model.gradient_checkpointing = False
|
||||
if hasattr(language_model, "model"):
|
||||
inner = language_model.model
|
||||
if hasattr(inner, "gradient_checkpointing_disable"):
|
||||
inner.gradient_checkpointing_disable()
|
||||
elif hasattr(inner, "gradient_checkpointing"):
|
||||
inner.gradient_checkpointing = False
|
||||
return
|
||||
|
||||
def _enable_ckpt(module: nn.Module | None) -> bool:
|
||||
if module is None:
|
||||
return False
|
||||
if hasattr(module, "gradient_checkpointing_enable"):
|
||||
try:
|
||||
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
|
||||
except TypeError:
|
||||
module.gradient_checkpointing_enable()
|
||||
return True
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = True
|
||||
return True
|
||||
return False
|
||||
|
||||
enabled_any = _enable_ckpt(self.model)
|
||||
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
encoder = self.model.vision_model.encoder
|
||||
encoder.gradient_checkpointing = True
|
||||
_patch_vision_encoder_checkpointing(
|
||||
encoder, use_reentrant=self.gradient_checkpointing_use_reentrant
|
||||
)
|
||||
enabled_any = True
|
||||
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
enabled_any = _enable_ckpt(language_model) or enabled_any
|
||||
if hasattr(language_model, "model"):
|
||||
enabled_any = _enable_ckpt(language_model.model) or enabled_any
|
||||
if hasattr(language_model, "config"):
|
||||
language_model.config.use_cache = False
|
||||
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
if hasattr(self.model, "enable_input_require_grads"):
|
||||
self.model.enable_input_require_grads()
|
||||
|
||||
if enabled_any:
|
||||
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
|
||||
else:
|
||||
logger.warning(
|
||||
"Requested gradient checkpointing, but model does not expose checkpointing controls."
|
||||
)
|
||||
|
||||
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(image, torch.Tensor):
|
||||
pil_image = to_pil_image(image.detach().cpu())
|
||||
else:
|
||||
pil_image = image.convert("RGB")
|
||||
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
|
||||
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
|
||||
device=self.device, dtype=torch.bfloat16
|
||||
)
|
||||
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
return (tile_tensors - mean) / std
|
||||
|
||||
def _preprocess_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, list[list[int]]]:
|
||||
pixel_values_list = []
|
||||
batch_num_tiles_list: list[list[int]] = []
|
||||
|
||||
for image_tensors in image_tensors_batch:
|
||||
num_tiles_list: list[int] = []
|
||||
for image in image_tensors:
|
||||
tiles = self._preprocess_single_image(image)
|
||||
pixel_values_list.append(tiles)
|
||||
num_tiles_list.append(int(tiles.shape[0]))
|
||||
batch_num_tiles_list.append(num_tiles_list)
|
||||
|
||||
if pixel_values_list:
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
else:
|
||||
pixel_values = torch.empty(
|
||||
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
return pixel_values, batch_num_tiles_list
|
||||
|
||||
def _build_multimodal_prompts(
|
||||
self,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
text_prompts: Sequence[str],
|
||||
) -> list[str]:
|
||||
prompts = []
|
||||
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
|
||||
prompt_segments = []
|
||||
for i, tile_count in enumerate(num_tiles_list):
|
||||
token_count = self.model.num_image_token * tile_count
|
||||
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
|
||||
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
|
||||
prompts.append("".join(prompt_segments) + text_prompt.strip())
|
||||
return prompts
|
||||
|
||||
def _prepare_and_fuse_embeddings(
|
||||
self,
|
||||
prompts: Sequence[str],
|
||||
vit_embeds: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
|
||||
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
|
||||
if true_sequence_length > self.max_text_length:
|
||||
logger.warning(
|
||||
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
|
||||
self.max_text_length,
|
||||
true_sequence_length,
|
||||
)
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
list(prompts),
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_text_length,
|
||||
).to(self.device)
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs["attention_mask"]
|
||||
|
||||
img_token_mask = input_ids == self.img_context_token_id
|
||||
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
|
||||
|
||||
batch_size, _, channels = input_embeds.shape
|
||||
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
|
||||
tokens_per_tile = self.model.num_image_token
|
||||
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
|
||||
|
||||
vit_idx = 0
|
||||
for batch_index in range(batch_size):
|
||||
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
|
||||
mask_b = img_token_mask[batch_index]
|
||||
actual_vis_tokens = actual_vis_tokens_list[batch_index]
|
||||
|
||||
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
|
||||
vit_idx += expected_vis_tokens
|
||||
if actual_vis_tokens > 0:
|
||||
if item_vit_embeds.shape[0] < actual_vis_tokens:
|
||||
raise ValueError(
|
||||
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
|
||||
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
|
||||
)
|
||||
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
|
||||
|
||||
current_token_idx = 0
|
||||
img_token_locations = torch.where(mask_b)[0]
|
||||
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
|
||||
num_tokens_for_image = num_tiles * tokens_per_tile
|
||||
if not bool(image_masks[batch_index, image_index].item()):
|
||||
start_offset = current_token_idx
|
||||
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
|
||||
if start_offset < end_offset:
|
||||
idxs = img_token_locations[start_offset:end_offset]
|
||||
attention_mask[batch_index, idxs] = 0
|
||||
current_token_idx += num_tokens_for_image
|
||||
|
||||
return input_embeds, attention_mask
|
||||
|
||||
def get_fused_image_text_embedding_from_tensor_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
image_masks: torch.Tensor,
|
||||
text_prompts: Sequence[str],
|
||||
return_cls_only: bool = True,
|
||||
):
|
||||
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
|
||||
if pixel_values.shape[0] == 0:
|
||||
logger.warning("InternVL3 received an empty image batch after preprocessing.")
|
||||
hidden_size = getattr(self.model.config, "hidden_size", None)
|
||||
if hidden_size is None and hasattr(self.model.language_model, "config"):
|
||||
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
|
||||
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
|
||||
return empty
|
||||
|
||||
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
|
||||
vit_embeds = self.model.extract_feature(pixel_values)
|
||||
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
|
||||
prompts,
|
||||
vit_embeds,
|
||||
image_masks.to(device=self.device),
|
||||
batch_num_tiles_list,
|
||||
)
|
||||
|
||||
outputs = self.model.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
|
||||
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.model.parameters()).device
|
||||
@@ -1,450 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.policies.evo1.evo1_model import EVO1
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class EVO1Policy(PreTrainedPolicy):
|
||||
config_class = Evo1Config
|
||||
name = "evo1"
|
||||
|
||||
def __init__(self, config: Evo1Config, **kwargs):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
|
||||
if len(config.image_features) > config.max_views:
|
||||
raise ValueError(
|
||||
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model = EVO1(self._build_model_config(config))
|
||||
self.model.set_finetune_flags()
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
if strict is None:
|
||||
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_model_config(config: Evo1Config) -> dict:
|
||||
return {
|
||||
"device": config.device,
|
||||
"return_cls_only": config.return_cls_only,
|
||||
"vlm_name": config.vlm_model_name,
|
||||
"vlm_num_layers": config.vlm_num_layers,
|
||||
"vlm_dtype": config.vlm_dtype,
|
||||
"use_flash_attn": config.use_flash_attn,
|
||||
"action_head": config.action_head,
|
||||
"action_horizon": config.chunk_size,
|
||||
"per_action_dim": config.max_action_dim,
|
||||
"state_dim": config.max_state_dim,
|
||||
"embed_dim": config.embed_dim,
|
||||
"hidden_dim": config.hidden_dim,
|
||||
"state_hidden_dim": config.state_hidden_dim,
|
||||
"num_heads": config.num_heads,
|
||||
"num_layers": config.num_layers,
|
||||
"dropout": config.dropout,
|
||||
"num_inference_timesteps": config.num_inference_timesteps,
|
||||
"num_categories": config.num_categories,
|
||||
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
|
||||
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
|
||||
"finetune_vlm": config.finetune_vlm,
|
||||
"finetune_language_model": config.finetune_language_model,
|
||||
"finetune_vision_model": config.finetune_vision_model,
|
||||
"finetune_action_head": config.finetune_action_head,
|
||||
}
|
||||
|
||||
@property
|
||||
def _camera_keys(self) -> list[str]:
|
||||
return list(self.config.image_features)
|
||||
|
||||
@property
|
||||
def _env_action_dim(self) -> int:
|
||||
action_feature = self.config.action_feature
|
||||
if action_feature is None:
|
||||
return self.config.max_action_dim
|
||||
return int(action_feature.shape[0])
|
||||
|
||||
@property
|
||||
def _compute_dtype(self) -> torch.dtype:
|
||||
return next(self.model.action_head.parameters()).dtype
|
||||
|
||||
@property
|
||||
def _training_compute_dtype(self) -> torch.dtype:
|
||||
if str(self.config.device).startswith("cuda"):
|
||||
return torch.bfloat16
|
||||
return self._compute_dtype
|
||||
|
||||
@property
|
||||
def _inference_compute_dtype(self) -> torch.dtype:
|
||||
if str(self.config.device).startswith("cuda") and self.config.use_amp:
|
||||
return torch.bfloat16
|
||||
return self._compute_dtype
|
||||
|
||||
def get_optim_params(self) -> list[dict]:
|
||||
decay, no_decay = [], []
|
||||
for name, param in self.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
is_bias = name.endswith("bias") or ".bias" in name
|
||||
is_norm = param.dim() == 1 or "norm" in name.lower()
|
||||
if is_bias or is_norm:
|
||||
no_decay.append(param)
|
||||
else:
|
||||
decay.append(param)
|
||||
return [
|
||||
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
|
||||
{"params": no_decay, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
|
||||
prompts = batch.get(self.config.task_field)
|
||||
if prompts is None and self.config.task_field != "task":
|
||||
prompts = batch.get("task")
|
||||
if prompts is None:
|
||||
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
|
||||
if isinstance(prompts, str):
|
||||
return [prompts]
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [str(prompt) for prompt in prompts]
|
||||
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
if OBS_STATE not in batch:
|
||||
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
|
||||
state = batch[OBS_STATE]
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
elif state.dim() == 3:
|
||||
state = state[:, -1]
|
||||
elif state.dim() != 2:
|
||||
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
|
||||
batch_size, state_dim = state.shape
|
||||
if state_dim > self.config.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
|
||||
)
|
||||
explicit_mask = batch.get("state_mask")
|
||||
if explicit_mask is not None:
|
||||
if explicit_mask.dim() == 1:
|
||||
explicit_mask = explicit_mask.unsqueeze(0)
|
||||
elif explicit_mask.dim() == 3:
|
||||
explicit_mask = explicit_mask[:, -1]
|
||||
elif explicit_mask.dim() != 2:
|
||||
raise ValueError(
|
||||
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||
)
|
||||
if explicit_mask.shape != (batch_size, state_dim):
|
||||
raise ValueError(
|
||||
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
|
||||
)
|
||||
padded = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_state_dim,
|
||||
dtype=state.dtype,
|
||||
device=self.config.device,
|
||||
)
|
||||
padded[:, :state_dim] = state.to(device=self.config.device)
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_state_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
if explicit_mask is None:
|
||||
mask[:, :state_dim] = True
|
||||
else:
|
||||
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||
return padded.to(dtype=self._compute_dtype), mask
|
||||
|
||||
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
if ACTION not in batch:
|
||||
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
|
||||
action = batch[ACTION]
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1)
|
||||
batch_size, horizon, action_dim = action.shape
|
||||
if horizon != self.config.chunk_size:
|
||||
raise ValueError(
|
||||
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
|
||||
)
|
||||
if action_dim > self.config.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
|
||||
)
|
||||
explicit_mask = batch.get("action_mask")
|
||||
if explicit_mask is not None:
|
||||
if explicit_mask.dim() == 2:
|
||||
if horizon == 1:
|
||||
explicit_mask = explicit_mask.unsqueeze(1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
|
||||
)
|
||||
elif explicit_mask.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||
)
|
||||
if explicit_mask.shape != (batch_size, horizon, action_dim):
|
||||
raise ValueError(
|
||||
"action_mask shape "
|
||||
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
|
||||
)
|
||||
padded = torch.zeros(
|
||||
batch_size,
|
||||
horizon,
|
||||
self.config.max_action_dim,
|
||||
dtype=action.dtype,
|
||||
device=self.config.device,
|
||||
)
|
||||
padded[:, :, :action_dim] = action.to(device=self.config.device)
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
horizon,
|
||||
self.config.max_action_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
if explicit_mask is None:
|
||||
mask[:, :, :action_dim] = True
|
||||
else:
|
||||
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||
return padded.to(dtype=self._compute_dtype), mask
|
||||
|
||||
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_action_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
mask[:, : self._env_action_dim] = True
|
||||
return mask
|
||||
|
||||
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
|
||||
embodiment_ids = batch.get("embodiment_id")
|
||||
if embodiment_ids is None and self.config.embodiment_id_field:
|
||||
embodiment_ids = batch.get(self.config.embodiment_id_field)
|
||||
if embodiment_ids is None:
|
||||
return torch.full(
|
||||
(batch_size,),
|
||||
self.config.default_embodiment_id,
|
||||
dtype=torch.long,
|
||||
device=self.config.device,
|
||||
)
|
||||
if embodiment_ids.dim() == 0:
|
||||
embodiment_ids = embodiment_ids.unsqueeze(0)
|
||||
elif embodiment_ids.dim() > 1:
|
||||
embodiment_ids = embodiment_ids[:, -1]
|
||||
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
|
||||
|
||||
@property
|
||||
def _tracks_vlm_gradients(self) -> bool:
|
||||
return bool(
|
||||
self.config.finetune_vlm
|
||||
or self.config.finetune_language_model
|
||||
or self.config.finetune_vision_model
|
||||
)
|
||||
|
||||
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
|
||||
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
|
||||
if not camera_keys:
|
||||
raise ValueError("EVO1 requires at least one visual observation feature.")
|
||||
|
||||
# Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
|
||||
# from a real batch dim and not from C in the unbatched (C, H, W) case.
|
||||
normalized: dict[str, Tensor] = {}
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
image = batch[camera_key]
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
elif image.dim() == 5:
|
||||
image = image[:, -1]
|
||||
elif image.dim() != 4:
|
||||
raise ValueError(
|
||||
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
|
||||
)
|
||||
normalized[camera_key] = image
|
||||
|
||||
batch_size = normalized[camera_keys[0]].shape[0]
|
||||
image_batches: list[list[Tensor]] = []
|
||||
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
|
||||
|
||||
for batch_index in range(batch_size):
|
||||
sample_images: list[Tensor] = []
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
|
||||
if not sample_images:
|
||||
raise ValueError("EVO1 received a batch without any image tensor.")
|
||||
while len(sample_images) < self.config.max_views:
|
||||
sample_images.append(torch.zeros_like(sample_images[0]))
|
||||
image_batches.append(sample_images[: self.config.max_views])
|
||||
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
|
||||
|
||||
return image_batches, image_masks
|
||||
|
||||
def _compute_fused_tokens(
|
||||
self,
|
||||
prompts: list[str],
|
||||
image_batches: list[list[Tensor]],
|
||||
image_masks: Tensor,
|
||||
) -> Tensor:
|
||||
track_vlm_gradients = self._tracks_vlm_gradients
|
||||
grad_context = nullcontext() if track_vlm_gradients else torch.no_grad()
|
||||
embedder = getattr(self.model, "embedder", None)
|
||||
embedder_was_training = embedder.training if embedder is not None else None
|
||||
|
||||
if not track_vlm_gradients and embedder is not None:
|
||||
embedder.eval()
|
||||
|
||||
try:
|
||||
with grad_context:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
finally:
|
||||
if not track_vlm_gradients and embedder is not None and embedder_was_training is not None:
|
||||
embedder.train(embedder_was_training)
|
||||
|
||||
if not track_vlm_gradients:
|
||||
fused_tokens = fused_tokens.detach()
|
||||
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
|
||||
|
||||
def _compute_masked_loss(
|
||||
self,
|
||||
pred_velocity: Tensor,
|
||||
target_velocity: Tensor,
|
||||
action_mask: Tensor,
|
||||
reduction: str,
|
||||
) -> Tensor:
|
||||
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
|
||||
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
|
||||
active = flat_mask.sum(dim=1).clamp_min(1.0)
|
||||
per_sample_loss = sq_error.sum(dim=1) / active
|
||||
if reduction == "none":
|
||||
return per_sample_loss
|
||||
if reduction != "mean":
|
||||
raise ValueError(f"Unsupported reduction '{reduction}'")
|
||||
return sq_error.sum() / active.sum()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
prompts = self._normalize_task_batch(batch)
|
||||
image_batches, image_masks = self._collect_image_batches(batch)
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
actions_gt, action_mask = self._prepare_actions(batch)
|
||||
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
states = states.to(dtype=self._training_compute_dtype)
|
||||
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
|
||||
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
|
||||
pred_velocity, noise = self.model(
|
||||
fused_tokens,
|
||||
state=states,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
|
||||
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
|
||||
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
|
||||
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
|
||||
return loss, {
|
||||
"loss": loss_mean,
|
||||
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
prompts = self._normalize_task_batch(batch)
|
||||
image_batches, image_masks = self._collect_image_batches(batch)
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
states = states.to(dtype=self._inference_compute_dtype)
|
||||
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
action_mask = self._prepare_inference_action_mask(states.shape[0])
|
||||
|
||||
with (
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
if self.config.use_amp and str(self.config.device).startswith("cuda")
|
||||
else nullcontext()
|
||||
):
|
||||
actions = self.model(
|
||||
fused_tokens,
|
||||
state=states,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
|
||||
return actions[:, :, : self._env_action_dim]
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
if len(self._action_queue) == 0:
|
||||
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(action_chunk.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
@@ -1,106 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
DONE,
|
||||
INFO,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
REWARD,
|
||||
TRUNCATED,
|
||||
)
|
||||
|
||||
|
||||
def evo1_batch_to_transition(batch: dict[str, Any]):
|
||||
transition = batch_to_transition(batch)
|
||||
complementary_data = dict(transition.get("complementary_data") or {})
|
||||
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
|
||||
for key, value in batch.items():
|
||||
if key in reserved or key.startswith(OBS_PREFIX):
|
||||
continue
|
||||
complementary_data.setdefault(key, value)
|
||||
return create_transition(
|
||||
observation=transition.get("observation"),
|
||||
action=transition.get("action"),
|
||||
reward=transition.get("reward", 0.0),
|
||||
done=transition.get("done", False),
|
||||
truncated=transition.get("truncated", False),
|
||||
info=transition.get("info", {}),
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
|
||||
|
||||
def make_evo1_pre_post_processors(
|
||||
config: Evo1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=evo1_batch_to_transition,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -47,16 +47,17 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -89,7 +90,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x",
|
||||
"molmoact2".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -128,10 +130,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
return SACPolicy
|
||||
return GaussianActorPolicy
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
@@ -152,10 +154,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "evo1":
|
||||
from .evo1.modeling_evo1 import EVO1Policy
|
||||
elif name == "molmoact2":
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return EVO1Policy
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -172,8 +178,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "wall_x", "eo1", "evo1".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -196,8 +202,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
@@ -208,8 +214,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "evo1":
|
||||
return Evo1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -238,6 +246,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
dataset_meta: Any | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -271,26 +280,22 @@ def make_pre_post_processors(
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
# GROOT handles normalization in groot_pack_inputs_v3 step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
preprocessor_overrides = {}
|
||||
postprocessor_overrides = {}
|
||||
preprocessor_overrides["groot_pack_inputs_v3"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
}
|
||||
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
|
||||
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
"env_action_dim": env_action_dim,
|
||||
}
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
return make_groot_pre_post_processors_from_pretrained(
|
||||
config=policy_cfg,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
|
||||
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
|
||||
preprocessor_config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
postprocessor_config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -372,10 +377,10 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from .sac.processor_sac import make_sac_pre_post_processors
|
||||
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
processors = make_gaussian_actor_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
@@ -413,6 +418,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -420,10 +426,20 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, Evo1Config):
|
||||
from .evo1.processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
processors = make_evo1_pre_post_processors(
|
||||
elif isinstance(policy_cfg, MolmoAct2Config):
|
||||
from .molmoact2.processor_molmoact2 import make_molmoact2_pre_post_processors
|
||||
|
||||
processors = make_molmoact2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
@@ -513,6 +529,10 @@ def make_policy(
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
if ds_meta is not None:
|
||||
set_dataset_feature_metadata = getattr(cfg, "set_dataset_feature_metadata", None)
|
||||
if callable(set_dataset_feature_metadata):
|
||||
set_dataset_feature_metadata(ds_meta.features)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
|
||||
+4
-4
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .modeling_sac import SACPolicy
|
||||
from .processor_sac import make_sac_pre_post_processors
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
||||
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||
+37
-59
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
@@ -75,18 +75,19 @@ class PolicyConfig:
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
class GaussianActorConfig(PreTrainedConfig):
|
||||
"""Gaussian actor configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
CLI: ``--policy.type=gaussian_actor``.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Encoder architecture
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
|
||||
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Capacity of the online replay buffer
|
||||
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
# Network architecture
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters (Gaussian head)
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
# Any validation specific to GaussianActor configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
|
||||
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
|
||||
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
|
||||
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
|
||||
# this preset.
|
||||
default_lr = 3e-4
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
"actor": {"lr": default_lr},
|
||||
"critic": {"lr": default_lr},
|
||||
"temperature": {"lr": default_lr},
|
||||
},
|
||||
)
|
||||
|
||||
+48
-447
@@ -15,16 +15,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_sac import SACConfig, is_image_feature
|
||||
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
class GaussianActorPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
config: GaussianActorConfig | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -54,9 +49,8 @@ class SACPolicy(
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
self._init_discrete_critic()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
@@ -65,11 +59,7 @@ class SACPolicy(
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
@@ -79,7 +69,9 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
raise NotImplementedError(
|
||||
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -92,360 +84,43 @@ class SACPolicy(
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
if self.discrete_critic is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
else:
|
||||
discrete_action = torch.ones(
|
||||
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
|
||||
"""Actor forward pass: sample actions and return log-probabilities.
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
batch: A flat observation dict, or a training dict containing
|
||||
``"state"`` (observations) and optionally ``"observation_feature"``
|
||||
(pre-computed encoder features).
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
"""Initialize policy actor network."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
@@ -455,21 +130,25 @@ class SACPolicy(
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
def _init_discrete_critic(self) -> None:
|
||||
"""Initialize discrete critic network."""
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
class GaussianActorObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
def __init__(self, config: GaussianActorConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._init_image_layers()
|
||||
@@ -677,84 +356,6 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class DiscreteCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_min: float = -5,
|
||||
@@ -811,7 +412,7 @@ class Policy(nn.Module):
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.encoder: GaussianActorObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_min = std_min
|
||||
@@ -885,7 +486,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
+5
-5
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
def make_gaussian_actor_pre_post_processors(
|
||||
config: GaussianActorConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
config: The configuration object for the tanh-Gaussian policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
|
||||
Returns:
|
||||
@@ -18,4 +18,12 @@ from .configuration_groot import GrootConfig
|
||||
from .modeling_groot import GrootPolicy
|
||||
from .processor_groot import make_groot_pre_post_processors
|
||||
|
||||
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in {"GR00TN17", "GR00TN17Config"}:
|
||||
from .groot_n1_7 import GR00TN17, GR00TN17Config
|
||||
|
||||
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
b, t = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -42,6 +43,9 @@ else:
|
||||
Timesteps = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
require_package("diffusers", extra="groot")
|
||||
@@ -181,8 +185,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
@@ -266,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
logger.debug(
|
||||
"Total number of DiT parameters: %d",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
@@ -318,6 +321,71 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class AlternateVLDiT(DiT):
|
||||
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
|
||||
|
||||
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attend_text_every_n_blocks = attend_text_every_n_blocks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
image_mask: torch.Tensor | None = None,
|
||||
backbone_attention_mask: torch.Tensor | None = None,
|
||||
):
|
||||
if image_mask is None:
|
||||
raise ValueError("image_mask is required for AlternateVLDiT.")
|
||||
if backbone_attention_mask is None:
|
||||
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
|
||||
|
||||
temb = self.timestep_encoder(timestep)
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
image_attention_mask = image_mask & backbone_attention_mask
|
||||
non_image_attention_mask = (~image_mask) & backbone_attention_mask
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
if not self.config.interleave_self_attention:
|
||||
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
|
||||
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
curr_encoder_attention_mask = (
|
||||
non_image_attention_mask
|
||||
if idx % (2 * self.attend_text_every_n_blocks) == 0
|
||||
else image_attention_mask
|
||||
)
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=curr_encoder_attention_mask,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@@ -362,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
logger.debug(
|
||||
"Total number of SelfAttentionTransformer parameters: %d",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,408 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from .action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
if self._beta_dist is None:
|
||||
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -14,12 +14,327 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
"GR00T N1.5 support was removed from LeRobot. "
|
||||
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||
|
||||
_GROOT_MODEL_VERSION_ALIASES = {
|
||||
"n1.7": GROOT_N1_7,
|
||||
"n1_7": GROOT_N1_7,
|
||||
"n1d7": GROOT_N1_7,
|
||||
"n17": GROOT_N1_7,
|
||||
"1.7": GROOT_N1_7,
|
||||
}
|
||||
|
||||
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||
|
||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||
"none": None,
|
||||
"": None,
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
}
|
||||
|
||||
|
||||
def normalize_groot_model_version(model_version: str) -> str:
|
||||
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
||||
if normalized is None:
|
||||
supported = GROOT_N1_7
|
||||
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
|
||||
if transform is None:
|
||||
return None
|
||||
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
|
||||
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
|
||||
supported = ", ".join(
|
||||
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
|
||||
f"Supported transforms: none, {supported}."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_groot_model_version(model_path: str | None) -> str | None:
|
||||
if not model_path:
|
||||
return None
|
||||
model_path_lower = model_path.lower()
|
||||
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
|
||||
return GROOT_N1_7
|
||||
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
|
||||
# N1.5 is unsupported, but it must still be recognised here to fail loudly
|
||||
# rather than silently treating an N1.5 checkpoint as N1.7.
|
||||
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
|
||||
return GROOT_N1_5
|
||||
config_version = _infer_groot_model_version_from_local_config(model_path)
|
||||
if config_version is not None:
|
||||
return config_version
|
||||
return None
|
||||
|
||||
|
||||
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
|
||||
if model_path is None:
|
||||
return False
|
||||
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
|
||||
|
||||
|
||||
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
if "libero_sim" in modality_configs:
|
||||
return "libero_sim"
|
||||
if len(modality_configs) == 1:
|
||||
return next(iter(modality_configs))
|
||||
return None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
return None
|
||||
modality_configs = processor_kwargs.get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag is None:
|
||||
return None
|
||||
|
||||
embodiment_config = modality_configs.get(embodiment_tag, {})
|
||||
if not isinstance(embodiment_config, dict):
|
||||
return None
|
||||
action_config = embodiment_config.get("action", {})
|
||||
if not isinstance(action_config, dict):
|
||||
return None
|
||||
delta_indices = action_config.get("delta_indices", [])
|
||||
if not isinstance(delta_indices, list):
|
||||
return None
|
||||
return len(delta_indices) or None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_execution_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
|
||||
if action_horizon is None:
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag == "libero_sim":
|
||||
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
|
||||
# actions. Keeping that execution cadence avoids stale open-loop chunks.
|
||||
return min(action_horizon, 8)
|
||||
return action_horizon
|
||||
|
||||
|
||||
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
|
||||
model_path = Path(model_name).expanduser()
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
|
||||
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
|
||||
return str(cached_snapshot) if cached_snapshot is not None else model_name
|
||||
|
||||
|
||||
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
|
||||
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
|
||||
required_files = (
|
||||
"config.json",
|
||||
"tokenizer_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
)
|
||||
|
||||
for hub_cache in _candidate_hf_hub_caches(cache_dir):
|
||||
repo_cache = hub_cache / repo_cache_name
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
if not snapshots_dir.is_dir():
|
||||
continue
|
||||
|
||||
candidates: list[Path] = []
|
||||
ref_path = repo_cache / "refs" / "main"
|
||||
try:
|
||||
ref = ref_path.read_text().strip()
|
||||
except OSError:
|
||||
ref = ""
|
||||
if ref:
|
||||
candidates.append(snapshots_dir / ref)
|
||||
candidates.extend(
|
||||
sorted(
|
||||
(path for path in snapshots_dir.iterdir() if path.is_dir()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[Path] = set()
|
||||
for snapshot in candidates:
|
||||
if snapshot in seen:
|
||||
continue
|
||||
seen.add(snapshot)
|
||||
if all((snapshot / filename).exists() for filename in required_files):
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
|
||||
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if cache_dir is not None:
|
||||
cache_path = Path(cache_dir).expanduser()
|
||||
candidates.append(cache_path)
|
||||
candidates.append(cache_path / "hub")
|
||||
|
||||
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
||||
if hub_cache:
|
||||
candidates.append(Path(hub_cache).expanduser())
|
||||
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if hf_home:
|
||||
candidates.append(Path(hf_home).expanduser() / "hub")
|
||||
|
||||
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
|
||||
|
||||
deduped: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve() if candidate.exists() else candidate
|
||||
if resolved not in seen:
|
||||
seen.add(resolved)
|
||||
deduped.append(candidate)
|
||||
return deduped
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return None
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return _infer_groot_model_version_from_config(config)
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
model_version = config.get("model_version")
|
||||
if isinstance(model_version, str):
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
return GROOT_N1_5
|
||||
try:
|
||||
return normalize_groot_model_version(model_version)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
|
||||
for candidate in candidates:
|
||||
if not isinstance(candidate, str):
|
||||
continue
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
return GROOT_N1_7
|
||||
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||
backbone_cfg = config.get("backbone_cfg")
|
||||
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||
return GROOT_N1_5
|
||||
return None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
@@ -28,35 +343,44 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
chunk_size: int = 40
|
||||
n_action_steps: int = 40
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 64
|
||||
max_state_dim: int = 132
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 32
|
||||
max_action_dim: int = 132
|
||||
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
# Groot-specific model parameters
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
base_model_path: str | None = None
|
||||
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
@@ -96,17 +420,16 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
@@ -117,6 +440,66 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
self.base_model_path = GROOT_N1_7_BASE_MODEL
|
||||
|
||||
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
|
||||
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
|
||||
# wrapper applies this conversion; mirror it here so eval on the
|
||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||
# This matches the embodiment-specific handling already done for the
|
||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||
# 'none' (normalized to None above) keeps the transform disabled.
|
||||
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||
self.action_decode_transform = (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||
)
|
||||
|
||||
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||
# GrootConfig() never triggers this.
|
||||
legacy_default_remaps = (
|
||||
("max_state_dim", 64, 132),
|
||||
("max_action_dim", 32, 132),
|
||||
("chunk_size", 50, 40),
|
||||
("n_action_steps", 50, 40),
|
||||
("image_size", (224, 224), (256, 256)),
|
||||
)
|
||||
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||
current_value = getattr(self, field_name)
|
||||
if isinstance(legacy_value, tuple):
|
||||
current_value = tuple(current_value)
|
||||
if current_value == legacy_value:
|
||||
logger.warning(
|
||||
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||
"if this is not what you want.",
|
||||
field_name,
|
||||
legacy_value,
|
||||
n1_7_value,
|
||||
)
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
@@ -192,7 +575,10 @@ class GrootConfig(PreTrainedConfig):
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Eagle25VLConfig(PretrainedConfig):
|
||||
model_type = "eagle_2_5_vl"
|
||||
is_composition = True
|
||||
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-4,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
loss_version="v1",
|
||||
min_dynamic_tiles=1,
|
||||
max_dynamic_tiles=6,
|
||||
mlp_checkpoint=False,
|
||||
initializer_range=0.02,
|
||||
_attn_implementation="flash_attention_2",
|
||||
_attn_implementation_autoset=False,
|
||||
llm_config=None,
|
||||
image_token_index=None,
|
||||
use_pixel_shuffle=True,
|
||||
mlp_connector_layers=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"model_type": "siglip_vision_model"}
|
||||
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {"architectures": ["Qwen2ForCausalLM"]}
|
||||
logger.info(
|
||||
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
|
||||
if vision_config["model_type"] == "siglip_vision_model":
|
||||
self.vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
|
||||
|
||||
if text_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.text_config = LlamaConfig(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
|
||||
self.text_config = Qwen3Config(**text_config)
|
||||
else:
|
||||
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.mlp_checkpoint = mlp_checkpoint
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.loss_version = loss_version
|
||||
self.initializer_range = initializer_range
|
||||
self.min_dynamic_tiles = min_dynamic_tiles
|
||||
self.max_dynamic_tiles = max_dynamic_tiles
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self._attn_implementation = _attn_implementation
|
||||
self._attn_implementation_autoset = _attn_implementation_autoset
|
||||
self.image_token_index = image_token_index
|
||||
self.use_pixel_shuffle = use_pixel_shuffle
|
||||
self.mlp_connector_layers = mlp_connector_layers
|
||||
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
|
||||
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["pad2square"] = self.pad2square
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["min_dynamic_tiles"] = self.min_dynamic_tiles
|
||||
output["max_dynamic_tiles"] = self.max_dynamic_tiles
|
||||
output["tie_word_embeddings"] = self.tie_word_embeddings
|
||||
output["_attn_implementation"] = self._attn_implementation
|
||||
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
|
||||
output["use_pixel_shuffle"] = self.use_pixel_shuffle
|
||||
output["mlp_connector_layers"] = self.mlp_connector_layers
|
||||
return output
|
||||
@@ -1,503 +0,0 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
ImagesKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
||||
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_flat_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from transformers.image_utils import pil_torch_interpolation_mapping
|
||||
else:
|
||||
from torchvision.transforms import functional as F # noqa: N812
|
||||
|
||||
|
||||
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
|
||||
"""Crop the given numpy array.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
|
||||
left (int): The left coordinate of the crop box.
|
||||
top (int): The top coordinate of the crop box.
|
||||
right (int): The right coordinate of the crop box.
|
||||
bottom (int): The bottom coordinate of the crop box.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cropped image.
|
||||
"""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
|
||||
|
||||
if img.ndim not in [2, 3]:
|
||||
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
|
||||
|
||||
img_height = img.shape[1]
|
||||
img_width = img.shape[2]
|
||||
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
||||
raise ValueError("Crop coordinates out of bounds")
|
||||
|
||||
if top >= bottom or left >= right:
|
||||
raise ValueError("Invalid crop coordinates")
|
||||
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
pad_during_tiling: bool | None
|
||||
do_pad: bool | None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processing videos.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
max_dynamic_tiles = 12
|
||||
min_dynamic_tiles = 1
|
||||
use_thumbnail = True
|
||||
pad_during_tiling = False
|
||||
valid_kwargs = Eagle25VLFastImageProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
max_dynamic_tiles (`int`, *optional*):
|
||||
The maximum number of dynamic tiles to use for processing high resolution images.
|
||||
min_dynamic_tiles (`int`, *optional*):
|
||||
The minimum number of dynamic tiles to use for processing high resolution images.
|
||||
use_thumbnail (`bool`, *optional*):
|
||||
Whether to use a thumbnail for processing high resolution images.
|
||||
pad_during_tiling (`bool`, *optional*):
|
||||
Whether to pad the image during tiling.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
|
||||
# NOTE(YL): we will overload the preprocess method to add the image_flags
|
||||
# def preprocess(
|
||||
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
|
||||
# ) -> BatchFeature:
|
||||
# return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
expected_ndims: int = 3,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
expected_ndims (`int`, *optional*, defaults to 3):
|
||||
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_flat_list_of_images(images)
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_resolution: tuple,
|
||||
interpolation: F.InterpolationMode,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image ("torch.Tensor"):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
"torch.Tensor": The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
|
||||
|
||||
return resized_image
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
"""
|
||||
previous version mainly focus on ratio.
|
||||
We also consider area ratio here.
|
||||
"""
|
||||
best_factor = float("-inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
|
||||
"""
|
||||
new area > 60% of original image area is enough.
|
||||
"""
|
||||
factor_based_on_area_n_ratio = min(
|
||||
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
|
||||
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
|
||||
|
||||
if factor_based_on_area_n_ratio > best_factor:
|
||||
best_factor = factor_based_on_area_n_ratio
|
||||
best_ratio = ratio
|
||||
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
|
||||
|
||||
return padded_image
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: F.InterpolationMode,
|
||||
pad_during_tiling: bool,
|
||||
) -> list[torch.Tensor]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = tile_size * target_aspect_ratio[0]
|
||||
target_height = tile_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if pad_during_tiling:
|
||||
resized_image = self._resize_for_patching(
|
||||
image,
|
||||
(target_height, target_width),
|
||||
interpolation=interpolation,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
padded_image = self._pad_for_patching(
|
||||
resized_image,
|
||||
(target_height, target_width),
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
image_used_to_split = padded_image
|
||||
else:
|
||||
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
|
||||
|
||||
processed_tiles = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // tile_size)) * tile_size,
|
||||
(i // (target_width // tile_size)) * tile_size,
|
||||
((i % (target_width // tile_size)) + 1) * tile_size,
|
||||
((i // (target_width // tile_size)) + 1) * tile_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
|
||||
processed_tiles.append(split_img)
|
||||
assert len(processed_tiles) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_tiles) != 1:
|
||||
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
|
||||
processed_tiles.append(thumbnail_img)
|
||||
|
||||
return processed_tiles
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[torch.Tensor]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
|
||||
Returns:
|
||||
List[`torch.Tensor`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: F.InterpolationMode | None,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
do_pad: bool,
|
||||
return_tensors: str | TensorType | None,
|
||||
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
tile_size = crop_size.height
|
||||
elif size and size.height:
|
||||
tile_size = size.height
|
||||
else:
|
||||
tile_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
min_num=min_dynamic_tiles,
|
||||
max_num=max_dynamic_tiles,
|
||||
size=size_tuple,
|
||||
tile_size=tile_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
interpolation=interpolation,
|
||||
pad_during_tiling=pad_during_tiling,
|
||||
)
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
||||
image_patches,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
image_mean,
|
||||
image_std,
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(
|
||||
processed_image_patches_grouped, grouped_image_patches_index
|
||||
)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
|
||||
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
|
||||
if images is not None:
|
||||
images = self._prepare_image_like_inputs(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
videos = self._prepare_image_like_inputs(
|
||||
images=videos,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
resample = kwargs.pop("resample", self.resample)
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, PILImageResampling | int)
|
||||
else resample
|
||||
)
|
||||
|
||||
# Filter kwargs to only include those accepted by _preprocess
|
||||
valid_preprocess_kwargs = {
|
||||
"do_resize",
|
||||
"size",
|
||||
"max_dynamic_tiles",
|
||||
"min_dynamic_tiles",
|
||||
"use_thumbnail",
|
||||
"pad_during_tiling",
|
||||
"interpolation",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"pad_size",
|
||||
"disable_grouping",
|
||||
}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
|
||||
if images is not None:
|
||||
return self._preprocess(images, **filtered_kwargs)
|
||||
elif videos is not None:
|
||||
return self._preprocess(videos, **filtered_kwargs)
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLImageProcessorFast"]
|
||||
@@ -1,395 +0,0 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as cp
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
||||
from transformers.utils import add_start_docstrings, logging
|
||||
|
||||
from .configuration_eagle2_5_vl import Eagle25VLConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
|
||||
EAGLE2_5_VL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Eagle25VLConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
|
||||
EAGLE2_5_VL_START_DOCSTRING,
|
||||
)
|
||||
class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
config_class = Eagle25VLConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Qwen2DecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"Siglip2EncoderLayer",
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear | nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
|
||||
config_class = Eagle25VLConfig
|
||||
|
||||
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
|
||||
super().__init__(config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
if config.use_pixel_shuffle:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
|
||||
else:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2)
|
||||
|
||||
self.select_layer = config.select_layer
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.loss_version = config.loss_version
|
||||
self.mlp_checkpoint = config.mlp_checkpoint
|
||||
self.use_pixel_shuffle = config.use_pixel_shuffle
|
||||
self.mlp_connector_layers = config.mlp_connector_layers
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
|
||||
if vision_model is not None:
|
||||
self.vision_model = vision_model
|
||||
else:
|
||||
if config.vision_config.model_type == "siglip_vision_model":
|
||||
config.vision_config._attn_implementation = "flash_attention_2"
|
||||
self.vision_model = SiglipVisionModel(config.vision_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
|
||||
|
||||
if language_model is not None:
|
||||
self.language_model = language_model
|
||||
else:
|
||||
if config.text_config.architectures[0] == "LlamaForCausalLM":
|
||||
self.language_model = LlamaForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
|
||||
raise NotImplementedError("Phi3 is not implemented.")
|
||||
# self.language_model = Phi3ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
assert config.text_config._attn_implementation == "flash_attention_2", (
|
||||
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
|
||||
)
|
||||
self.language_model = Qwen2ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(config.text_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
if config.mlp_connector_layers == 2:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size, llm_hidden_size),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
|
||||
|
||||
self.image_token_index = config.image_token_index
|
||||
self.neftune_alpha = None
|
||||
|
||||
if config.use_backbone_lora:
|
||||
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
||||
|
||||
self.use_llm_lora = config.use_llm_lora
|
||||
if config.use_llm_lora:
|
||||
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
||||
|
||||
self.check_forward_kwargs()
|
||||
|
||||
def check_forward_kwargs(self):
|
||||
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
|
||||
# has special handling for functions with **kwargs parameters that would affect
|
||||
# how our model is processed during training and inference.
|
||||
forward_params = inspect.signature(self.forward).parameters
|
||||
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
|
||||
|
||||
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.out_proj",
|
||||
"mlp.fc1",
|
||||
"mlp.fc2",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
||||
self.vision_model.print_trainable_parameters()
|
||||
|
||||
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.up_proj",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.language_model = get_peft_model(self.language_model, lora_config)
|
||||
self.language_model.enable_input_require_grads()
|
||||
self.language_model.print_trainable_parameters()
|
||||
self.use_llm_lora = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
image_flags: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_tiles_list: list[torch.Tensor] | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
if image_flags is not None:
|
||||
image_flags = image_flags.view(-1)
|
||||
vit_embeds = vit_embeds[image_flags == 1]
|
||||
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.image_token_index
|
||||
try:
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
|
||||
except Exception as e:
|
||||
vit_embeds = vit_embeds.reshape(-1, c)
|
||||
print(
|
||||
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
|
||||
f"vit_embeds.shape={vit_embeds.shape}"
|
||||
)
|
||||
n_token = selected.sum()
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
|
||||
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
)
|
||||
if hasattr(vit_embeds, "last_hidden_state"):
|
||||
vit_embeds = vit_embeds.last_hidden_state
|
||||
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
|
||||
if self.use_pixel_shuffle:
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
|
||||
|
||||
if self.mlp_checkpoint and vit_embeds.requires_grad:
|
||||
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
|
||||
else:
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
|
||||
return vit_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
input_ids: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
visual_features: torch.FloatTensor | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
image_sizes: list[tuple[int, int]] | None = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
if pixel_values is not None:
|
||||
if visual_features is not None:
|
||||
vit_embeds = visual_features
|
||||
else:
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.config.image_token_index
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
else:
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if "use_cache" not in generate_kwargs:
|
||||
generate_kwargs["use_cache"] = True
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
@@ -1,542 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Eagle25VL.
|
||||
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 256
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == "RGBA":
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
||||
image = ele["image"] if "image" in ele else ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image, stream=True, timeout=10)
|
||||
image_obj = Image.open(BytesIO(response.content))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = to_rgb(image_obj)
|
||||
if "scale_factor" in ele:
|
||||
scale_factor = ele["scale_factor"]
|
||||
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
||||
return image
|
||||
|
||||
|
||||
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
"videos_kwargs": {"max_dynamic_tiles": 1},
|
||||
}
|
||||
|
||||
|
||||
class Eagle25VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
|
||||
|
||||
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
|
||||
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one imagethat will be returned by vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<IMG_CONTEXT>", # nosec: B107
|
||||
video_token="<IMG_CONTEXT>", # nosec: B107
|
||||
tokens_per_tile=256,
|
||||
image_placeholder="image",
|
||||
video_placeholder="video",
|
||||
image_start_token="<img>",
|
||||
image_end_token="</img>",
|
||||
**kwargs,
|
||||
):
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
self.video_token_id = (
|
||||
tokenizer.video_token_id
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
self.image_placeholder = image_placeholder
|
||||
self.video_placeholder = video_placeholder
|
||||
self.tokens_per_tile = tokens_per_tile
|
||||
self.image_start_token = image_start_token
|
||||
self.image_end_token = image_end_token
|
||||
if "auto_map" in kwargs:
|
||||
self.auto_map = kwargs["auto_map"]
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def replace_media_placeholder(
|
||||
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
|
||||
):
|
||||
num_of_images_in_this_sample = 0
|
||||
num_of_videos_in_this_sample = 0
|
||||
# Regular expression pattern to match formats like <image-1> or <video-2>
|
||||
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
|
||||
unified_frame_list = []
|
||||
|
||||
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
# )
|
||||
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
# )
|
||||
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
|
||||
# "use_thumbnail", self.image_processor.use_thumbnail
|
||||
# )
|
||||
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
)
|
||||
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
)
|
||||
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
|
||||
"use_thumbnail", self.image_processor.use_thumbnail
|
||||
)
|
||||
|
||||
tile_size = self.image_processor.size.get("height", 448)
|
||||
|
||||
# Function to replace tags in a single text
|
||||
def replace_in_text(text):
|
||||
# repl callback function for each match replacement operation
|
||||
def repl(match):
|
||||
nonlocal unified_frame_list
|
||||
nonlocal num_of_images_in_this_sample
|
||||
nonlocal num_of_videos_in_this_sample
|
||||
media_type = match.group(1) # 'image' or 'video'
|
||||
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
|
||||
# Select the corresponding path based on media type
|
||||
idx_mapper = {
|
||||
0: "first",
|
||||
1: "second",
|
||||
2: "third",
|
||||
3: "fourth",
|
||||
4: "fifth",
|
||||
5: "sixth",
|
||||
6: "seventh",
|
||||
7: "eighth",
|
||||
8: "ninth",
|
||||
9: "tenth",
|
||||
}
|
||||
if media_type == "image":
|
||||
image_inputs = self.image_processor(
|
||||
images=[image_list[idx_in_list]],
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
if isinstance(image_inputs["pixel_values"], list):
|
||||
_pv = image_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
image_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
num_of_images_in_this_sample += 1
|
||||
|
||||
elif media_type == "video":
|
||||
video_inputs = self.image_processor(
|
||||
images=None,
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
if isinstance(video_inputs["pixel_values"], list):
|
||||
_pv = video_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
video_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
frame_timestamps = timestamps_list[idx_in_list]
|
||||
else:
|
||||
frame_timestamps = None
|
||||
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
|
||||
|
||||
num_of_tiles_each_frame = [
|
||||
self.get_number_tiles_based_on_image_size(
|
||||
image_size,
|
||||
video_min_dynamic_tiles,
|
||||
video_max_dynamic_tiles,
|
||||
video_use_thumbnail,
|
||||
tile_size,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
|
||||
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
|
||||
)
|
||||
|
||||
if frame_timestamps is not None:
|
||||
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
|
||||
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
|
||||
)
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
else:
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
|
||||
if sampled_fps is not None:
|
||||
special_placeholder = (
|
||||
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
|
||||
+ "".join(special_placeholder)
|
||||
)
|
||||
else:
|
||||
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
|
||||
special_placeholder
|
||||
)
|
||||
unified_frame_list.append(video_inputs)
|
||||
num_of_videos_in_this_sample += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_type}")
|
||||
return special_placeholder
|
||||
|
||||
return pattern.sub(repl, text)
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
|
||||
def _to_tensor(v):
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v
|
||||
if isinstance(v, list):
|
||||
if v and isinstance(v[0], list):
|
||||
v = [t for sub in v for t in sub]
|
||||
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||
return torch.as_tensor(v)
|
||||
|
||||
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
return (
|
||||
text,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Eagle25VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text_list = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
text_list = text
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if videos is None:
|
||||
videos = []
|
||||
|
||||
pixel_values_list = []
|
||||
image_sizes_list = []
|
||||
new_sample_list = []
|
||||
image_start_idx = 0
|
||||
video_start_idx = 0
|
||||
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
|
||||
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
for sample in text_list:
|
||||
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
|
||||
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
|
||||
(
|
||||
sample,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
) = self.replace_media_placeholder(
|
||||
sample,
|
||||
images[image_start_idx:],
|
||||
videos[video_start_idx:],
|
||||
timestamps_list,
|
||||
fps_list,
|
||||
**output_kwargs,
|
||||
)
|
||||
new_sample_list.append(sample)
|
||||
if pixel_values is not None:
|
||||
pixel_values_list.append(pixel_values)
|
||||
image_sizes_list.append(image_sizes)
|
||||
image_start_idx += num_of_images_in_this_sample
|
||||
video_start_idx += num_of_videos_in_this_sample
|
||||
|
||||
if len(pixel_values_list) > 0:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_list),
|
||||
"image_sizes": torch.cat(image_sizes_list),
|
||||
}
|
||||
else:
|
||||
image_inputs = {}
|
||||
video_inputs = {}
|
||||
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def get_number_tiles_based_on_image_size(
|
||||
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of tiles based on the image size.
|
||||
"""
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if use_thumbnail and tiles_num > 1:
|
||||
tiles_num += 1
|
||||
return tiles_num
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
return processor
|
||||
|
||||
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def process_vision_info(
|
||||
self,
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
return_video_kwargs: bool = False,
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
|
||||
vision_infos = self.extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
video_sample_fps_list = []
|
||||
video_timestamps_list = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
if return_video_kwargs:
|
||||
return (
|
||||
image_inputs,
|
||||
video_inputs,
|
||||
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
|
||||
)
|
||||
return image_inputs, video_inputs
|
||||
|
||||
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if (
|
||||
"image" in ele
|
||||
or "image_url" in ele
|
||||
or "video" in ele
|
||||
or ele["type"] in ("image", "image_url", "video")
|
||||
):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLProcessor"]
|
||||
@@ -1,374 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
|
||||
|
||||
from .action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHead,
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from .utils import ensure_eagle_cache_ready
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
class EagleBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
|
||||
project_to_dim: int = 1536,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tune_llm: whether to tune the LLM model (default: True)
|
||||
tune_visual: whether to tune the visual model (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
|
||||
|
||||
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
|
||||
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
try:
|
||||
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
|
||||
|
||||
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
|
||||
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if project_to_dim is not None:
|
||||
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
|
||||
else:
|
||||
self.eagle_linear = torch.nn.Identity()
|
||||
|
||||
# needed since we don't use these layers. Also saves compute
|
||||
while len(self.eagle_model.language_model.model.layers) > select_layer:
|
||||
self.eagle_model.language_model.model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual)
|
||||
|
||||
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.eagle_model.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.eagle_model.vision_model.requires_grad_(False)
|
||||
self.eagle_model.mlp1.requires_grad_(False)
|
||||
print(f"Tune backbone llm: {self.tune_llm}")
|
||||
print(f"Tune backbone visual: {self.tune_visual}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_llm and not tune_visual:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Backbone trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No backbone trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if self.eagle_model.language_model and not self.tune_llm:
|
||||
self.eagle_model.language_model.eval()
|
||||
if self.eagle_model.vision_model and not self.tune_visual:
|
||||
self.eagle_model.vision_model.eval()
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
eagle_prefix = "eagle_"
|
||||
eagle_input = {
|
||||
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
|
||||
}
|
||||
del eagle_input["image_sizes"]
|
||||
|
||||
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
|
||||
eagle_features = eagle_output.hidden_states[self.select_layer]
|
||||
|
||||
eagle_features = self.eagle_linear(eagle_features)
|
||||
return eagle_features, eagle_input["attention_mask"]
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
|
||||
|
||||
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
|
||||
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
|
||||
if self.training and self.tune_visual:
|
||||
dummy_term = torch.tensor(
|
||||
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
|
||||
)
|
||||
for param in self.eagle_model.vision_model.parameters():
|
||||
if param.requires_grad:
|
||||
dummy_term = dummy_term + 0.0 * param.sum()
|
||||
eagle_embeds = eagle_embeds + dummy_term
|
||||
|
||||
return BatchFeature(
|
||||
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
|
||||
) # [B, T2, hidden_size]
|
||||
|
||||
|
||||
BACKBONE_FEATURE_KEY = "backbone_features"
|
||||
ACTION_KEY = "action_pred"
|
||||
LOSS_KEY = "loss"
|
||||
ERROR_MSG = "Error: unexpected input/output"
|
||||
N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
# real model
|
||||
class GR00TN15(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = GR00TN15Config
|
||||
"""
|
||||
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
|
||||
here n is variable and can be e.g. time, 1 or user specified
|
||||
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
|
||||
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN15Config,
|
||||
local_model_path: str,
|
||||
):
|
||||
assert isinstance(config.backbone_cfg, dict)
|
||||
assert isinstance(config.action_head_cfg, dict)
|
||||
|
||||
super().__init__(config)
|
||||
self.local_model_path = local_model_path
|
||||
|
||||
self.backbone = EagleBackbone(**config.backbone_cfg)
|
||||
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
|
||||
self.action_head = FlowmatchingActionHead(action_head_cfg)
|
||||
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
self.post_init()
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if ACTION in inputs:
|
||||
action = inputs[ACTION]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
elif isinstance(action, torch.Tensor):
|
||||
shape_ok = (
|
||||
len(action.shape) == 3
|
||||
and action.shape[1] == self.action_horizon
|
||||
and action.shape[2] == self.action_dim
|
||||
)
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{action.shape=}"
|
||||
detected_error = True
|
||||
else:
|
||||
# Unexpected non-tensor type provided for action
|
||||
error_msg += f"\nInvalid type for action: {type(action)}"
|
||||
detected_error = True
|
||||
|
||||
if "video" in inputs:
|
||||
video = inputs["video"]
|
||||
type_ok = isinstance(video, np.ndarray)
|
||||
dtype_ok = video.dtype == np.uint8
|
||||
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
|
||||
if not type_ok:
|
||||
error_msg += f"\n{type(video)=}"
|
||||
detected_error = True
|
||||
if not dtype_ok:
|
||||
error_msg += f"\n{video.dtype=}"
|
||||
detected_error = True
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{video.shape=}"
|
||||
detected_error = True
|
||||
|
||||
if detected_error:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
|
||||
fail_backbone = (
|
||||
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
|
||||
)
|
||||
|
||||
if fail_backbone:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
|
||||
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
|
||||
(
|
||||
LOSS_KEY in action_head_outputs and is_training
|
||||
) # there might not be an action prediction during training
|
||||
or (
|
||||
ACTION_KEY in action_head_outputs
|
||||
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
|
||||
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
|
||||
)
|
||||
)
|
||||
|
||||
if fail_action_head:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
|
||||
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
|
||||
error_msg += f"\n{self.action_horizon=}"
|
||||
error_msg += f"\n{self.action_dim=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
|
||||
return action_head_outputs
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
|
||||
return action_head_outputs
|
||||
|
||||
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
|
||||
self.validate_inputs(inputs)
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_maybe_dtype(x):
|
||||
# Cast floating tensors to a memory-efficient compute dtype when requested.
|
||||
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
|
||||
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
if getattr(self, "compute_dtype", None) == "bfloat16":
|
||||
return x.to(self.device, dtype=torch.bfloat16)
|
||||
# Fallback: preserve previous behavior if not using bf16 compute
|
||||
return x.to(self.device, dtype=self.action_head.dtype)
|
||||
# Non-floating tensors: move device only
|
||||
return x.to(self.device)
|
||||
|
||||
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
|
||||
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
|
||||
return backbone_inputs, action_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
|
||||
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
|
||||
print(f"Tune backbone vision tower: {tune_visual}")
|
||||
print(f"Tune backbone LLM: {tune_llm}")
|
||||
print(f"Tune action head projector: {tune_projector}")
|
||||
print(f"Tune action head DiT: {tune_diffusion_model}")
|
||||
|
||||
# get the current model path being downloaded
|
||||
try:
|
||||
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
|
||||
# saved in ~/.cache/huggingface/hub/
|
||||
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
|
||||
# HFValidationError, RepositoryNotFoundError
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
print(
|
||||
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
|
||||
)
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path, local_model_path=local_model_path, **kwargs
|
||||
)
|
||||
|
||||
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
|
||||
)
|
||||
return pretrained_model
|
||||
@@ -0,0 +1,966 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
except ImportError:
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _copy_default(value: Any) -> Any:
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
"model_name": "nvidia/Cosmos-Reason2-2B",
|
||||
"backbone_model_type": "qwen",
|
||||
"model_revision": None,
|
||||
"tune_top_llm_layers": 0,
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
"backbone_trainable_params_fp32": True,
|
||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||
"shortest_image_edge": None,
|
||||
"crop_fraction": None,
|
||||
"random_rotation_angle": None,
|
||||
"color_jitter_params": None,
|
||||
"use_albumentations_transforms": True,
|
||||
"extra_augmentation_config": None,
|
||||
"formalize_language": True,
|
||||
"apply_sincos_state_encoding": False,
|
||||
"use_percentiles": True,
|
||||
"use_relative_action": False,
|
||||
"max_state_dim": 132,
|
||||
"max_action_dim": 132,
|
||||
"action_horizon": 40,
|
||||
"hidden_size": 1024,
|
||||
"input_embedding_dim": 1536,
|
||||
"state_history_length": 1,
|
||||
"add_pos_embed": True,
|
||||
"attn_dropout": 0.2,
|
||||
"use_vlln": True,
|
||||
"max_seq_len": 1024,
|
||||
"use_alternate_vl_dit": True,
|
||||
"attend_text_every_n_blocks": 2,
|
||||
"diffusion_model_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 48,
|
||||
"norm_type": "ada_norm",
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
"output_dim": 1024,
|
||||
"interleave_self_attention": True,
|
||||
},
|
||||
"vl_self_attention_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 4,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 64,
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
},
|
||||
"num_inference_timesteps": 4,
|
||||
"noise_beta_alpha": 1.5,
|
||||
"noise_beta_beta": 1.0,
|
||||
"noise_s": 0.999,
|
||||
"num_timestep_buckets": 1000,
|
||||
"tune_projector": True,
|
||||
"tune_diffusion_model": True,
|
||||
"tune_vlln": True,
|
||||
"state_dropout_prob": 0.2,
|
||||
"exclude_state": False,
|
||||
"use_mean_std": False,
|
||||
"max_num_embodiments": 32,
|
||||
"rtc_ramp_rate": 6.0,
|
||||
}
|
||||
|
||||
|
||||
class GR00TN17Config(PretrainedConfig):
|
||||
"""Configuration for NVIDIA GR00T N1.7.
|
||||
|
||||
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
|
||||
flow-matching action head. This mirrors the public N1.7 checkpoint config
|
||||
while keeping it local to LeRobot and independent from the external
|
||||
Isaac-GR00T ``gr00t`` Python package.
|
||||
"""
|
||||
|
||||
model_type = "Gr00tN1d7"
|
||||
|
||||
_defaults = GR00T_N1_7_DEFAULTS
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in GR00T_N1_7_DEFAULTS.items():
|
||||
setattr(self, key, _copy_default(kwargs.pop(key, value)))
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
|
||||
cfg = self.to_dict()
|
||||
if not exclude_augment:
|
||||
return cfg
|
||||
exclude_keys = {
|
||||
"random_rotation_angle",
|
||||
"color_jitter_params",
|
||||
"use_albumentations_transforms",
|
||||
"formalize_language",
|
||||
"image_crop_size",
|
||||
"image_target_size",
|
||||
"shortest_image_edge",
|
||||
"crop_fraction",
|
||||
}
|
||||
return {k: v for k, v in cfg.items() if k not in exclude_keys}
|
||||
|
||||
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
||||
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
"""Linear layer with category-specific weights for multi-embodiment support."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
"""Two-layer MLP with category-specific weights."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
|
||||
|
||||
The frequency scalar is intentionally created on CPU and then broadcast with
|
||||
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
|
||||
embedding and avoids tiny dtype/device construction differences in parity
|
||||
tests.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
"""Action encoder with category-specific projections and sinusoidal time encoding."""
|
||||
|
||||
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
|
||||
super().__init__()
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, horizon, _ = actions.shape
|
||||
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,).")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
|
||||
action_emb = self.W1(actions, cat_ids)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
|
||||
return self.W3(x, cat_ids)
|
||||
|
||||
|
||||
class Qwen3Backbone(nn.Module):
|
||||
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
|
||||
|
||||
The public checkpoint stores the action head in the GR00T checkpoint but
|
||||
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
|
||||
keeps the nested HF module layout compatible across transformer versions
|
||||
and exposes the hidden states consumed by the action head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "nvidia/Cosmos-Reason2-2B",
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
tune_top_llm_layers: int = 0,
|
||||
trainable_params_fp32: bool = False,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
|
||||
"or use a transformers version that provides Qwen3-VL support."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if use_flash_attention:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
|
||||
extra_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
except ImportError:
|
||||
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
|
||||
extra_kwargs["attn_implementation"] = "sdpa"
|
||||
if load_bf16:
|
||||
extra_kwargs["torch_dtype"] = torch.bfloat16
|
||||
|
||||
if load_pretrained_weights:
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
**extra_kwargs,
|
||||
**transformers_loading_kwargs,
|
||||
).eval()
|
||||
else:
|
||||
self.model = self._from_backbone_config(
|
||||
model_name=model_name,
|
||||
model_kwargs=extra_kwargs,
|
||||
config_kwargs=transformers_loading_kwargs,
|
||||
).eval()
|
||||
|
||||
while len(self.language_model.layers) > select_layer:
|
||||
self.language_model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
|
||||
if load_bf16 and trainable_params_fp32:
|
||||
for parameter in self.parameters():
|
||||
if parameter.requires_grad:
|
||||
parameter.data = parameter.data.to(torch.float32)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
|
||||
) -> None:
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.visual.requires_grad_(False)
|
||||
if tune_top_llm_layers > 0:
|
||||
for layer in self.language_model.layers[-tune_top_llm_layers:]:
|
||||
for parameter in layer.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if self.language_model and not self.tune_llm:
|
||||
self.language_model.eval()
|
||||
if self.visual and not self.tune_visual:
|
||||
self.visual.eval()
|
||||
|
||||
@property
|
||||
def language_model(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).language_model
|
||||
|
||||
@property
|
||||
def visual(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).visual
|
||||
|
||||
def _from_backbone_config(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
config_kwargs: dict[str, Any],
|
||||
) -> nn.Module:
|
||||
if _is_cosmos_reason2_backbone(model_name):
|
||||
backbone_config = _cosmos_reason2_qwen3_vl_config()
|
||||
else:
|
||||
if AutoConfig is None:
|
||||
raise ImportError(
|
||||
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
|
||||
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
if "mm_token_type_ids" in model_input:
|
||||
return
|
||||
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
|
||||
return
|
||||
|
||||
input_ids = model_input.get("input_ids")
|
||||
if input_ids is None:
|
||||
return
|
||||
|
||||
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[input_ids == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[input_ids == video_token_id] = 2
|
||||
|
||||
model_input["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
|
||||
|
||||
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
|
||||
drops text position ids before calling text-layer flash attention. GR00T
|
||||
N1.7 was aligned against the older Transformers path, where a fourth text
|
||||
position row is forwarded alongside the temporal/height/width rows. Adding
|
||||
the row here preserves the newer multimodal position computation while
|
||||
keeping flash attention on the legacy code path.
|
||||
"""
|
||||
|
||||
if "position_ids" in model_input:
|
||||
return
|
||||
|
||||
qwen3_model = getattr(self.model, "model", self.model)
|
||||
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
|
||||
if compute_3d_position_ids is None:
|
||||
return
|
||||
|
||||
position_ids = compute_3d_position_ids(
|
||||
input_ids=model_input.get("input_ids"),
|
||||
image_grid_thw=model_input.get("image_grid_thw"),
|
||||
video_grid_thw=model_input.get("video_grid_thw"),
|
||||
inputs_embeds=None,
|
||||
attention_mask=model_input.get("attention_mask"),
|
||||
past_key_values=None,
|
||||
mm_token_type_ids=model_input.get("mm_token_type_ids"),
|
||||
)
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
|
||||
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
|
||||
|
||||
model_input["position_ids"] = position_ids
|
||||
|
||||
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
|
||||
|
||||
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
|
||||
Newer releases expose the post-final-norm tensor there instead. Capturing
|
||||
the last decoder layer output directly keeps the N1.7 action head input
|
||||
stable across Transformers versions.
|
||||
"""
|
||||
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
|
||||
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
|
||||
if isinstance(output, torch.Tensor):
|
||||
captured["features"] = output
|
||||
elif isinstance(output, (tuple, list)) and output:
|
||||
captured["features"] = output[0]
|
||||
elif hasattr(output, "last_hidden_state"):
|
||||
captured["features"] = output.last_hidden_state
|
||||
|
||||
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
|
||||
try:
|
||||
outputs = self.model(**model_input, output_hidden_states=True)
|
||||
finally:
|
||||
hook.remove()
|
||||
|
||||
return captured.get("features", outputs.hidden_states[-1])
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
|
||||
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
|
||||
model_input = {key: vl_input[key] for key in keys_to_use}
|
||||
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
|
||||
self._ensure_mm_token_type_ids(model_input)
|
||||
self._ensure_legacy_qwen3_position_ids(model_input)
|
||||
features = self._last_decoder_layer_output(model_input)
|
||||
image_mask = model_input["input_ids"] == self.model.config.image_token_id
|
||||
attention_mask = model_input["attention_mask"] == 1
|
||||
return BatchFeature(
|
||||
data={
|
||||
"backbone_features": features,
|
||||
"backbone_attention_mask": attention_mask,
|
||||
"image_mask": image_mask,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class GR00TN17ActionHead(nn.Module):
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, config: GR00TN17Config):
|
||||
require_package("diffusers", extra="groot")
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
if config.use_alternate_vl_dit:
|
||||
self.model = AlternateVLDiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
|
||||
)
|
||||
else:
|
||||
self.model = DiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
)
|
||||
|
||||
self.action_dim = config.max_action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim * config.state_history_length,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
|
||||
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
|
||||
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
|
||||
else:
|
||||
self.vl_self_attention = nn.Identity()
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
self.state_dropout_prob = config.state_dropout_prob
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
|
||||
) -> None:
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
self.tune_vlln = tune_vlln
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
if not tune_vlln:
|
||||
self.vlln.requires_grad_(False)
|
||||
self.vl_self_attention.requires_grad_(False)
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
if not self.tune_vlln:
|
||||
self.vlln.eval()
|
||||
self.vl_self_attention.eval()
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
if self._beta_dist is None:
|
||||
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
|
||||
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
|
||||
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (1 - sample) * self.config.noise_s
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = self.vlln(backbone_output["backbone_features"])
|
||||
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
vl_embeds = backbone_output.backbone_features
|
||||
device = vl_embeds.device
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
if action_input.state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = action_input.state.view(action_input.state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, embodiment_id)
|
||||
|
||||
if self.training and self.state_dropout_prob > 0:
|
||||
do_dropout = (
|
||||
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
|
||||
)
|
||||
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
|
||||
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None]
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
)
|
||||
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
|
||||
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
|
||||
return BatchFeature(
|
||||
data={
|
||||
"loss": loss,
|
||||
"action_loss": action_loss,
|
||||
"action_mask": action_mask,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
state = action_input.state
|
||||
if state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = state.view(state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, action_input.embodiment_id)
|
||||
return BatchFeature(
|
||||
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action_with_features(
|
||||
self,
|
||||
backbone_features: torch.Tensor,
|
||||
state_features: torch.Tensor,
|
||||
embodiment_id: torch.Tensor,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
vl_embeds = backbone_features
|
||||
batch_size = vl_embeds.shape[0]
|
||||
device = vl_embeds.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.action_dim),
|
||||
dtype=vl_embeds.dtype,
|
||||
device=device,
|
||||
)
|
||||
dt = 1.0 / self.num_inference_timesteps
|
||||
vel_strength = torch.ones_like(actions)
|
||||
|
||||
if "action" in action_input:
|
||||
if options is None:
|
||||
raise ValueError("RTC options are required when action is provided to get_action.")
|
||||
action_horizon_before_padding = options["action_horizon"]
|
||||
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
|
||||
:,
|
||||
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
|
||||
:,
|
||||
]
|
||||
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
|
||||
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
|
||||
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
|
||||
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
|
||||
ramp = ramp / ramp[-1].clamp_min(1e-8)
|
||||
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
|
||||
None, :, None
|
||||
].to(device)
|
||||
|
||||
for t_step in range(self.num_inference_timesteps):
|
||||
t_cont = t_step / float(self.num_inference_timesteps)
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"action_pred": actions,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(
|
||||
self,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
features = self._encode_features(backbone_output, action_input)
|
||||
return self.get_action_with_features(
|
||||
backbone_features=features.backbone_features,
|
||||
state_features=features.state_features,
|
||||
embodiment_id=action_input.embodiment_id,
|
||||
backbone_output=backbone_output,
|
||||
action_input=action_input,
|
||||
options=options,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
|
||||
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
|
||||
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
|
||||
|
||||
|
||||
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
if Qwen3VLConfig is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLConfig is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
return Qwen3VLConfig(
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=True,
|
||||
text_config={
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "qwen3_vl_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"rope_scaling": {
|
||||
"mrope_interleaved": True,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"rope_type": "default",
|
||||
},
|
||||
"rope_theta": 5000000,
|
||||
"tie_word_embeddings": True,
|
||||
"use_cache": True,
|
||||
"vocab_size": 151936,
|
||||
},
|
||||
vision_config={
|
||||
"deepstack_visual_indexes": [5, 11, 17],
|
||||
"depth": 24,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1024,
|
||||
"in_channels": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"model_type": "qwen3_vl",
|
||||
"num_heads": 16,
|
||||
"num_position_embeddings": 2304,
|
||||
"out_hidden_size": 2048,
|
||||
"patch_size": 16,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_backbone_cls(config: GR00TN17Config):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
config.model_name,
|
||||
)
|
||||
return Qwen3Backbone
|
||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||
|
||||
|
||||
class GR00TN17(PreTrainedModel):
|
||||
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
|
||||
|
||||
config_class = GR00TN17Config
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN17Config,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_backbone_weights: bool = True,
|
||||
):
|
||||
super().__init__(config)
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
self.config = config
|
||||
backbone_cls = get_backbone_cls(config)
|
||||
self.backbone = backbone_cls(
|
||||
model_name=config.model_name,
|
||||
tune_llm=config.tune_llm,
|
||||
tune_visual=config.tune_visual,
|
||||
select_layer=config.select_layer,
|
||||
reproject_vision=config.reproject_vision,
|
||||
use_flash_attention=config.use_flash_attention,
|
||||
load_bf16=config.load_bf16,
|
||||
tune_top_llm_layers=config.tune_top_llm_layers,
|
||||
trainable_params_fp32=config.backbone_trainable_params_fp32,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_pretrained_weights=load_backbone_weights,
|
||||
)
|
||||
self.action_head = GR00TN17ActionHead(config)
|
||||
self.post_init()
|
||||
|
||||
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
|
||||
global tree
|
||||
if tree is None:
|
||||
require_package("dm-tree", extra="groot", import_name="tree")
|
||||
tree = importlib.import_module("tree")
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_dtype(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
return x.to(self.device, dtype=self.dtype)
|
||||
return x.to(self.device)
|
||||
|
||||
return (
|
||||
tree.map_structure(to_device_with_dtype, backbone_inputs),
|
||||
tree.map_structure(to_device_with_dtype, action_inputs),
|
||||
)
|
||||
|
||||
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head(backbone_outputs, action_inputs)
|
||||
|
||||
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head.get_action(backbone_outputs, action_inputs, options)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
tune_vlln = kwargs.pop("tune_vlln", True)
|
||||
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
try:
|
||||
local_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
repo_type="model",
|
||||
revision=kwargs.get("revision"),
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
token=kwargs.get("token"),
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_backbone_weights=load_backbone_weights,
|
||||
**kwargs,
|
||||
)
|
||||
pretrained_model.backbone.set_trainable_parameters(
|
||||
tune_visual=tune_visual,
|
||||
tune_llm=tune_llm,
|
||||
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
|
||||
)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector,
|
||||
tune_diffusion_model=tune_diffusion_model,
|
||||
tune_vlln=tune_vlln,
|
||||
)
|
||||
return pretrained_model
|
||||
|
||||
|
||||
def _register_with_transformers() -> None:
|
||||
if AutoConfig is None or AutoModel is None:
|
||||
return
|
||||
try:
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
|
||||
try:
|
||||
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoModel.register(GR00TN17Config, GR00TN17)
|
||||
|
||||
|
||||
_register_with_transformers()
|
||||
@@ -17,22 +17,13 @@
|
||||
"""
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T components where possible
|
||||
without porting their code. The intent is to:
|
||||
|
||||
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
||||
- Optionally align action horizon similar to gr00t_finetune.py
|
||||
- Expose predict_action via GR00T model.get_action
|
||||
- Provide a training forward that can call the GR00T model forward if batch
|
||||
structure matches.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||
possible without porting their code. Dataset loading and training
|
||||
orchestration are handled by LeRobot's standard training stack.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -46,8 +37,19 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_groot import GrootConfig
|
||||
from .groot_n1 import GR00TN15
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_5,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
infer_groot_model_version,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
@@ -67,37 +69,35 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
self._action_queue_steps = self._resolve_action_queue_steps()
|
||||
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T model using Isaac-GR00T API.
|
||||
|
||||
This is only called when creating a NEW policy (not when loading from checkpoint).
|
||||
|
||||
Steps (delegating to Isaac-GR00T):
|
||||
1) Download and load pretrained model via GR00TN15.from_pretrained
|
||||
2) Align action horizon with data_config if provided
|
||||
"""
|
||||
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
model = GR00TN15.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.base_model_path,
|
||||
tune_llm=self.config.tune_llm,
|
||||
tune_visual=self.config.tune_visual,
|
||||
tune_projector=self.config.tune_projector,
|
||||
tune_diffusion_model=self.config.tune_diffusion_model,
|
||||
)
|
||||
model_kwargs = {
|
||||
"pretrained_model_name_or_path": self.config.base_model_path,
|
||||
"tune_llm": self.config.tune_llm,
|
||||
"tune_visual": self.config.tune_visual,
|
||||
"tune_projector": self.config.tune_projector,
|
||||
"tune_diffusion_model": self.config.tune_diffusion_model,
|
||||
}
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
model = GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=True,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
self._action_queue = deque([], maxlen=self._action_queue_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -118,7 +118,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
1. Base GR00T N1.7 models - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
@@ -141,9 +141,15 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
requested_version = (
|
||||
normalize_groot_model_version(config.model_version)
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
pretrained_name_or_path,
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
@@ -174,7 +180,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -190,11 +196,15 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
config = GrootConfig(
|
||||
model_version=model_version,
|
||||
base_model_path=str(pretrained_name_or_path),
|
||||
)
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
@@ -215,6 +225,16 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
@@ -225,21 +245,160 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _resolve_action_queue_steps(self) -> int:
|
||||
n_action_steps = int(self.config.n_action_steps)
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
execution_horizon = infer_groot_n1_7_action_execution_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
horizons = [n_action_steps]
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
if execution_horizon is not None:
|
||||
horizons.append(execution_horizon)
|
||||
return min(horizons)
|
||||
|
||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||
|
||||
horizons = [actions.shape[1]]
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
|
||||
for horizon in (self.config.chunk_size, self.config.n_action_steps):
|
||||
horizon = int(horizon)
|
||||
if horizon > 0:
|
||||
horizons.append(horizon)
|
||||
|
||||
return max(1, min(horizons))
|
||||
|
||||
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
if include_action:
|
||||
allowed_base.update({"action", "action_mask"})
|
||||
|
||||
allowed_base.update(
|
||||
{
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
}
|
||||
)
|
||||
allowed_base.add("action_mask")
|
||||
|
||||
return {
|
||||
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
self,
|
||||
inputs: dict[str, Tensor],
|
||||
*,
|
||||
inference_delay: object,
|
||||
prev_chunk_left_over: object,
|
||||
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
|
||||
if prev_chunk_left_over is None:
|
||||
return inputs, None
|
||||
if not isinstance(prev_chunk_left_over, torch.Tensor):
|
||||
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
|
||||
if prev_chunk_left_over.numel() == 0:
|
||||
return inputs, None
|
||||
|
||||
prev_actions = prev_chunk_left_over
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
|
||||
batch_size = state.shape[0]
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
# provided prefix row as a real action constraint, so strip that padding
|
||||
# before constructing the native overlap options.
|
||||
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
|
||||
if valid_prefix_rows.any():
|
||||
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
|
||||
prev_actions = prev_actions[:, :valid_prefix_steps, :]
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(
|
||||
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||
)
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
|
||||
action_horizon = int(prev_actions.shape[1])
|
||||
if action_horizon <= 0:
|
||||
return inputs, None
|
||||
|
||||
if prev_actions.shape[2] > max_action_dim:
|
||||
prev_actions = prev_actions[:, :, :max_action_dim]
|
||||
elif prev_actions.shape[2] < max_action_dim:
|
||||
pad = torch.zeros(
|
||||
prev_actions.shape[0],
|
||||
prev_actions.shape[1],
|
||||
max_action_dim - prev_actions.shape[2],
|
||||
dtype=prev_actions.dtype,
|
||||
device=prev_actions.device,
|
||||
)
|
||||
prev_actions = torch.cat([prev_actions, pad], dim=2)
|
||||
|
||||
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
|
||||
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
|
||||
overlap_steps = max(0, min(action_horizon, execution_horizon))
|
||||
if overlap_steps == 0:
|
||||
return inputs, None
|
||||
|
||||
try:
|
||||
frozen_steps = int(inference_delay or 0)
|
||||
except (TypeError, ValueError):
|
||||
frozen_steps = 0
|
||||
frozen_steps = max(0, min(frozen_steps, overlap_steps))
|
||||
|
||||
options = {
|
||||
"action_horizon": action_horizon,
|
||||
"rtc_overlap_steps": overlap_steps,
|
||||
"rtc_frozen_steps": frozen_steps,
|
||||
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
|
||||
}
|
||||
|
||||
inputs = dict(inputs)
|
||||
inputs["action"] = prev_actions
|
||||
return inputs, options
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
Delegates to Isaac-GR00T model.forward when inputs are compatible.
|
||||
"""
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
@@ -248,38 +407,54 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
if loss is None:
|
||||
raise RuntimeError(
|
||||
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||
"'action' and 'action_mask'; check the preprocessor output."
|
||||
)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
|
||||
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
|
||||
action-overlap options before calling the underlying model.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch
|
||||
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
if groot_options is not None:
|
||||
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
|
||||
else:
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
prediction_horizon = self._resolve_prediction_horizon(actions)
|
||||
actions = actions[:, :prediction_horizon]
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
@@ -292,40 +467,28 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,47 +0,0 @@
|
||||
from pathlib import Path
|
||||
from shutil import copytree
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
|
||||
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
|
||||
|
||||
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
|
||||
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
vendor_dir = Path(vendor_dir)
|
||||
|
||||
try:
|
||||
# Populate/refresh cache with vendor files to ensure a complete processor directory
|
||||
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
|
||||
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
|
||||
|
||||
required_assets = [
|
||||
"vocab.json",
|
||||
"merges.txt",
|
||||
"added_tokens.json",
|
||||
"chat_template.json",
|
||||
"special_tokens_map.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
"preprocessor_config.json",
|
||||
"processor_config.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
|
||||
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
|
||||
|
||||
for fname in required_assets:
|
||||
dst = cache_dir / fname
|
||||
if not dst.exists():
|
||||
print(f"[GROOT] Fetching {fname}")
|
||||
hf_hub_download(
|
||||
repo_id=assets_repo,
|
||||
filename=fname,
|
||||
repo_type="model",
|
||||
local_dir=str(cache_dir),
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_molmoact2_README.md
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user