Compare commits

..

5 Commits

Author SHA1 Message Date
pepijn c5925399a9 style: remove decorative comment separator in transforms.py
Made-with: Cursor
2026-03-13 04:43:31 +00:00
pepijn f478ae5bfa docs: add Multi-Dataset Training guide
Covers feature mapping, auto-padding, per-dataset transforms,
weighted sampling, stats aggregation, and full config examples
for training across RoboCasa, LIBERO-plus, and RoboMME datasets.

Made-with: Cursor
2026-03-13 04:37:01 +00:00
pepijn b4d40d0228 feat: add MultiLeRobotDataset with weighted sampling and RoboMME env integration
Multi-dataset training support:
- NewMultiLeRobotDataset with per-dataset feature mapping, auto-padding,
  per-dataset transform pipelines, and weighted sampling
- MultiDatasetMeta shim compatible with EpisodeAwareSampler and make_policy
- WeightedEpisodeAwareSampler for proportional cross-dataset sampling
- SubDatasetConfig / MultiDatasetConfig in training configs
- DatasetTransformPipeline with built-in PadAction, PadState, ResizeImages
- Factory and training script wired up for multi-dataset path

RoboMME environment integration:
- RoboMMEEnv config and Gymnasium wrapper (robomme.py)
- robomme optional dependency in pyproject.toml

Made-with: Cursor
2026-03-13 04:31:35 +00:00
pepijn db5c26f07d feat(envs): add LIBERO-plus integration for evaluation benchmarks
Add LiberoPlusEnv config (subclass of LiberoEnv), register libero_plus
env type in factory, add import fallbacks for LIBERO-plus package
structure, and add libero_plus optional dependency group in pyproject.toml.

Made-with: Cursor
2026-03-12 04:31:09 +00:00
Pepijn 8904768db4 feat(envs): add RoboCasa composite-task benchmark integration
Integrates 5 selected RoboCasa kitchen tasks (3 short + 2 long) as a
LeRobot benchmark environment, following the same pattern as Libero.

Selected tasks:
  Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
  Long:  PrepareCoffee, RestockPantry

Changes:
- envs/robocasa.py: RoboCasaEnv wrapper with flat 12D Box action space,
  3-camera pixel obs, and 16D proprioceptive state
- envs/configs.py: RoboCasaEnv config with features_map
- envs/factory.py: wire robocasa into make_env + make_env_pre_post_processors
- processor/env_processor.py: RoboCasaProcessorStep for obs key remapping
- tests/test_robocasa_env.py: full test suite (auto-skips if assets missing)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-09 17:08:32 +01:00
27 changed files with 1762 additions and 1635 deletions
+2
View File
@@ -31,6 +31,8 @@
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
- local: multi_dataset_training
title: Multi-Dataset Training
title: "Datasets"
- sections:
- local: act
+1 -1
View File
@@ -165,7 +165,7 @@ hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
+232
View File
@@ -0,0 +1,232 @@
# Multi-Dataset Training
This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
## Overview
Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.
`MultiLeRobotDataset` lets you train on all of them jointly by:
- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization
## Configuration
Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.
### SubDatasetConfig fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |
### Example: Three-dataset config
```yaml
dataset:
type: multi
use_imagenet_stats: true
datasets:
# RoboCasa: 3 cameras, state(16), action(12)
- repo_id: pepijn223/robocasa_PrepareCoffee
root: /data/robocasa_PrepareCoffee
weight: 1.0
feature_map:
observation.images.robot0_agentview_left: observation.images.front_left
observation.images.robot0_agentview_right: observation.images.front_right
observation.images.robot0_eye_in_hand: observation.images.wrist
# LIBERO-plus: 2 cameras, state(8), action(7)
- repo_id: pepijn223/libero_plus_lerobot
root: /data/libero_plus_lerobot
weight: 0.5
feature_map:
observation.images.front: observation.images.front_left
observation.images.wrist: observation.images.wrist
transforms:
- type: pad_action
kwargs: {target_dim: 12}
- type: pad_state
kwargs: {target_dim: 16}
# RoboMME: 2 cameras (non-standard keys), state(8), action(8)
- repo_id: pepijn223/robomme_data_lerobot
root: /data/robomme_data_lerobot
weight: 0.3
feature_map:
image: observation.images.front_left
wrist_image: observation.images.wrist
state: observation.state
actions: action
transforms:
- type: pad_action
kwargs: {target_dim: 12}
- type: pad_state
kwargs: {target_dim: 16}
```
## Feature Mapping
The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.
After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
## Automatic Padding
When a sub-dataset doesn't have a feature that exists in the unified schema:
- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / False
A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
## Per-Dataset Transforms
Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.
### Built-in transforms
| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
### Custom transforms
You can register your own transforms in `lerobot/datasets/transforms.py`:
```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform
@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
def __init__(self, some_param: int):
self.some_param = some_param
def __call__(self, sample: dict) -> dict:
# Modify sample in-place or return a new dict
sample["action"] = sample["action"] * self.some_param
return sample
```
Then reference it in the config:
```yaml
transforms:
- type: my_transform
kwargs: {some_param: 2}
```
## Weighted Sampling
The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 16%.
This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
## Stats Aggregation
Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.
The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
## Training
Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:
```bash
python -m lerobot.scripts.lerobot_train \
--config_path path/to/multi_dataset_config.yaml
```
## Architecture
The data flow during training with `MultiLeRobotDataset`:
```
┌─────────────────────────────────────────────────────────┐
│ MultiLeRobotDataset.__getitem__(global_idx) │
│ │
│ 1. Map global_idx → (dataset_idx, local_idx) │
│ 2. Fetch sample from sub-dataset │
│ 3. Rename keys via feature_map │
│ 4. Apply per-dataset transforms (pad_action, etc.) │
│ 5. Zero-pad missing features + add _is_pad flags │
│ 6. Add dataset_index tag │
└─────────────────────┬───────────────────────────────────┘
┌────────────▼────────────┐
│ PyTorch DataLoader │
│ (collates into batch) │
└────────────┬────────────┘
┌────────────▼────────────┐
│ LeRobot Preprocessor │
│ (normalize, tokenize) │
└────────────┬────────────┘
┌────────────▼────────────┐
│ Policy forward + loss │
└─────────────────────────┘
```
## API Reference
### `NewMultiLeRobotDataset`
```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
dataset = NewMultiLeRobotDataset(
configs=[...], # list[SubDatasetConfig]
image_transforms=None, # optional image augmentation
delta_timestamps=None, # optional temporal neighbors
tolerance_s=1e-4, # timestamp tolerance
)
dataset.num_frames # total frames across all sub-datasets
dataset.num_episodes # total episodes
dataset.meta # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights # list of per-dataset weights
dataset.features # unified feature dict (union of all mapped features)
dataset.camera_keys # unified camera key list
```
### `WeightedEpisodeAwareSampler`
```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler
sampler = WeightedEpisodeAwareSampler(
dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
dataset_membership=dataset.meta.episodes["dataset_source"],
dataset_weights=dataset.dataset_weights,
shuffle=True,
)
```
### `DatasetTransformPipeline`
```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig
pipeline = DatasetTransformPipeline([
DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])
sample = pipeline(sample) # modifies the sample dict
```
+43 -84
View File
@@ -12,59 +12,36 @@ The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train
## Part 1: Getting Started
### Install the Unitree SDK
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation). Tested with `unitree_sdk2py==1.0.1` and `cyclonedds==0.10.2`:
### Install LeRobot on Your Machine
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
pip install -e .
cd ..
```
### Install LeRobot
```bash
conda install ffmpeg -c conda-forge
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
cd unitree_sdk2_python && pip install -e .
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### Test the Installation (Simulation)
The simulation environment has its own dependencies. Check the Simulation environment dependencies: [Unitree G1 Mujoco EnvHub](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main).
```bash
pip install mujoco loguru msgpack msgpack-numpy
```
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.id=wbc_unitree \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30, "warmup_s": 5}}' \
--display_data=true \
--robot.controller=GrootLocomotionController
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--display_data=true
```
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. You can connect a gamepad to your machine before launching in order to control the robot's locomotion in sim. We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) via `--robot.controller`.
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
- Press `9` to release the robot
- Press `7` / `8` to increase / decrease waist height
### Connect to the Physical Robot
### Connect to the Robot
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
@@ -82,11 +59,37 @@ ssh unitree@192.168.123.164
# Password: 123
```
### Share Internet via Ethernet
### Install LeRobot on the G1
The G1 needs internet access to clone repos and install packages. Share your laptop's connection over Ethernet:
From the robot:
**On your laptop:**
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
```
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
---
## Part 2: Enable WiFi on the Robot
Wi-Fi connectivity is blocked by default on the G1. To activate:
```bash
sudo rfkill unblock all
sudo ip link set wlan0 up
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
**On your laptop** (share internet via Ethernet):
```bash
sudo sysctl -w net.ipv4.ip_forward=1
@@ -97,7 +100,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
```
**On the G1:**
**On the G1** (set default route through your laptop):
```bash
sudo ip route del default 2>/dev/null || true
@@ -108,45 +111,6 @@ echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
ping -c 3 8.8.8.8
```
### Install the Unitree SDK on the G1
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation):
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
python -m pip install -e .
cd ..
```
### Install LeRobot on the G1
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
python -m pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### (Optional) Enable WiFi on the Robot
For wireless SSH access, you can enable WiFi on the G1 (it's blocked by default):
```bash
sudo rfkill unblock all
sudo ip link set wlan0 up
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
**Connect to a WiFi network:**
```bash
@@ -161,7 +125,7 @@ sudo nmcli connection up "YourNetwork"
ip a show wlan0
```
You can then SSH over WiFi instead of Ethernet:
You can now SSH over WiFi:
```bash
ssh unitree@<ROBOT_WIFI_IP>
@@ -170,23 +134,18 @@ ssh unitree@<ROBOT_WIFI_IP>
---
## Part 2: Teleoperation & Locomotion
## Part 3: Teleoperation & Locomotion
### Run the Robot Server
On the robot (from `~/lerobot`):
On the robot:
```bash
cd ~/lerobot
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
```
### Run the Locomotion Policy
You can run the teleoperation client from your laptop over Ethernet, over WiFi (experimental), or directly on the robot itself. Mind potential latency introduced by your network.
**From your laptop:**
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
@@ -199,13 +158,13 @@ lerobot-teleoperate \
--robot.controller=HolosomaLocomotionController
```
We support both [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) and [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) via `--robot.controller`.
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
---
## Part 3: Loco-Manipulation with the Homunculus Exoskeleton
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Check it out [here](https://github.com/nepyope/hmc_exo).
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
### Calibrate
@@ -246,7 +205,7 @@ Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/dat
---
## Part 4: Training & Inference
## Part 5: Training & Inference
### Train
+12 -4
View File
@@ -76,7 +76,7 @@ dependencies = [
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.14.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"pynput>=1.7.8,<1.9.0",
@@ -119,13 +119,14 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
# "unitree-sdk2==1.0.1",
"unitree-sdk2==1.0.1",
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0",
"onnx>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
@@ -174,6 +175,14 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero_plus = [
"lerobot[transformers-dep]",
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
"lerobot[scipy-dep]",
]
robomme = [
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
@@ -222,7 +231,6 @@ lerobot-eval="lerobot.scripts.lerobot_eval:main"
lerobot-train="lerobot.scripts.lerobot_train:main"
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
lerobot-dataset-subtask-annotate="lerobot.scripts.lerobot_subtask_annotate:main"
lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
+27 -6
View File
@@ -16,18 +16,13 @@
from dataclasses import dataclass, field
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
@dataclass
class DatasetConfig:
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
@@ -37,6 +32,32 @@ class DatasetConfig:
streaming: bool = False
@dataclass
class SubDatasetConfig:
"""Configuration for a single dataset within a MultiDatasetConfig."""
repo_id: str
root: str | None = None
episodes: list[int] | None = None
revision: str | None = None
video_backend: str = field(default_factory=get_safe_default_codec)
weight: float = 1.0
# Maps dataset-local feature keys to unified policy keys.
# Keys not listed pass through unchanged.
feature_map: dict[str, str] = field(default_factory=dict)
# Per-dataset transforms applied after feature renaming, before cross-dataset padding.
transforms: list[DatasetTransformStepConfig] | None = None
@dataclass
class MultiDatasetConfig:
"""Configuration for training on multiple datasets jointly."""
datasets: list[SubDatasetConfig] = field(default_factory=list)
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
use_imagenet_stats: bool = True
@dataclass
class WandBConfig:
enable: bool = False
+6 -6
View File
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
dataset: DatasetConfig | MultiDatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
@@ -129,8 +129,9 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if isinstance(self.dataset, MultiDatasetConfig):
if len(self.dataset.datasets) < 1:
raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@@ -143,8 +144,7 @@ class TrainPipelineConfig(HubMixin):
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
if self.use_rabc and not self.rabc_progress_path:
# Auto-detect from dataset path
if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig):
repo_id = self.dataset.repo_id
if self.dataset.root:
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
@@ -1,2 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Data annotations for subtasks and VLM-based labeling.
@@ -1,671 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import subprocess
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
import cv2
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
create_subtask_index_array,
create_subtasks_dataframe,
save_subtasks,
)
if TYPE_CHECKING:
from lerobot.data_processing.data_annotations.vlm_annotations import BaseVLM
# Skill Annotation Data Structures
class Skill:
"""Represents a single atomic skill/subtask in a demonstration."""
def __init__(self, name: str, start: float, end: float):
self.name = name
self.start = start # Start timestamp in seconds
self.end = end # End timestamp in seconds
def to_dict(self) -> dict:
return {"name": self.name, "start": self.start, "end": self.end}
@classmethod
def from_dict(cls, data: dict) -> "Skill":
return cls(name=data["name"], start=data["start"], end=data["end"])
def __repr__(self) -> str:
return f"Skill(name='{self.name}', start={self.start:.2f}, end={self.end:.2f})"
class EpisodeSkills:
"""Container for all skills in an episode."""
def __init__(self, episode_index: int, description: str, skills: list[Skill]):
self.episode_index = episode_index
self.description = description
self.skills = skills
def to_dict(self) -> dict:
return {
"episode_index": self.episode_index,
"description": self.description,
"skills": [s.to_dict() for s in self.skills],
}
# Video Extraction Utilities
class VideoExtractor:
"""Utilities for extracting and processing video segments from LeRobot datasets."""
def __init__(self) -> None:
pass
def extract_episode_video(
self,
video_path: Path,
start_timestamp: float,
end_timestamp: float,
target_fps: int = 1,
) -> Path:
"""
Extract a specific episode segment from a concatenated video file.
Args:
video_path: Path to the source video file
start_timestamp: Start time in seconds
end_timestamp: End time in seconds
target_fps: Target frames per second for output
Returns:
Path to the extracted temporary video file
"""
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
tmp_path = Path(tmp_file.name)
duration = end_timestamp - start_timestamp
print(f"Extracting: {start_timestamp:.1f}s - {end_timestamp:.1f}s ({duration:.1f}s)")
cmd = [
"ffmpeg",
"-i",
str(video_path),
"-ss",
str(start_timestamp),
"-t",
str(duration),
"-r",
str(target_fps),
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-crf",
"23",
"-an",
"-y",
str(tmp_path),
]
try:
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"FFmpeg failed: {e}") from e
except FileNotFoundError as e:
raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e
if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
if tmp_path.exists():
tmp_path.unlink()
raise RuntimeError("Video extraction produced invalid file")
return tmp_path
def add_timer_overlay(self, video_path: Path) -> Path:
"""
Add a visible timer overlay to each frame (elapsed time in seconds) in one corner.
Used so the VLM can read the timestamp from the image instead of relying on file metadata.
Draws a black box with white text at top-right. Writes to a new temporary file and returns its path.
"""
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
out_path = Path(out_file.name)
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
raise RuntimeError("Failed to open video")
fps = cap.get(cv2.CAP_PROP_FPS) or 1.0
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(str(out_path), fourcc, fps, (w, h))
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = max(1.2, min(h, w) / 350.0)
thickness = max(2, int(font_scale))
padding = 15
margin = 30
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
t_sec = frame_idx / fps
text = f"{t_sec:.2f} s"
(tw, th), baseline = cv2.getTextSize(text, font, font_scale, thickness)
# Top-right placement
x_text = w - tw - margin - padding
y_text = margin + th + padding
# Rectangle coordinates (black box behind text)
x1 = x_text - padding
y1 = y_text - th - padding
x2 = x_text + tw + padding
y2 = y_text + baseline + padding
# Draw black filled rectangle
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1)
# Draw white text
cv2.putText(
frame,
text,
(x_text, y_text),
font,
font_scale,
(255, 255, 255),
thickness,
lineType=cv2.LINE_AA,
)
writer.write(frame)
frame_idx += 1
cap.release()
writer.release()
if not out_path.exists() or out_path.stat().st_size < 1024:
if out_path.exists():
out_path.unlink()
raise RuntimeError("Timer overlay produced invalid file")
return out_path
def get_video_duration(self, video_path: Path) -> float:
"""Get duration of a video file in seconds."""
cap = cv2.VideoCapture(str(video_path))
fps = cap.get(cv2.CAP_PROP_FPS) or 30
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
return frame_count / fps
# Skill Annotation Pipeline
class SkillAnnotator:
"""
Main class for annotating LeRobot datasets with skill labels.
This class orchestrates the full annotation pipeline:
1. Load dataset
2. Extract video segments for each episode
3. Run VLM-based skill segmentation
4. Update dataset task metadata
"""
def __init__(
self,
vlm: "BaseVLM",
video_extractor: VideoExtractor | None = None,
batch_size: int = 8,
add_timer_overlay: bool = True,
):
self.vlm = vlm
self.video_extractor = video_extractor or VideoExtractor()
self.batch_size = batch_size
self.add_timer_overlay = add_timer_overlay
def annotate_dataset(
self,
dataset: LeRobotDataset,
video_key: str,
episodes: list[int] | None = None,
skip_existing: bool = False,
subtask_labels: list[str] | None = None,
) -> dict[int, EpisodeSkills]:
"""
Annotate all episodes in a dataset with skill labels using batched processing.
Args:
dataset: LeRobot dataset to annotate
video_key: Key for video observations (e.g., "observation.images.base")
episodes: Specific episode indices to annotate (None = all)
skip_existing: Skip episodes that already have skill annotations
subtask_labels: If provided, model must choose only from these labels (closed vocabulary)
Returns:
Dictionary mapping episode index to EpisodeSkills
"""
episode_indices = episodes or list(range(dataset.meta.total_episodes))
annotations: dict[int, EpisodeSkills] = {}
failed_episodes: dict[int, str] = {} # Track failed episodes with error messages
# Get coarse task description if available
coarse_goal = self._get_coarse_goal(dataset)
# Filter out episodes that already have annotations if skip_existing is True
if skip_existing:
existing_annotations = load_skill_annotations(dataset.root)
if existing_annotations and "episodes" in existing_annotations:
# Only skip episodes that exist AND have non-empty skills
existing_episode_indices = set()
for idx_str, episode_data in existing_annotations["episodes"].items():
idx = int(idx_str)
# Check if skills list exists and is not empty
if "skills" in episode_data and episode_data["skills"]:
existing_episode_indices.add(idx)
original_count = len(episode_indices)
episode_indices = [ep for ep in episode_indices if ep not in existing_episode_indices]
skipped_count = original_count - len(episode_indices)
if skipped_count > 0:
print(f"Skipping {skipped_count} episodes with existing non-empty annotations")
if not episode_indices:
print("No episodes to annotate (all already annotated)")
return annotations
print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...")
# Process episodes in batches
for batch_start in range(0, len(episode_indices), self.batch_size):
batch_end = min(batch_start + self.batch_size, len(episode_indices))
batch_episodes = episode_indices[batch_start:batch_end]
print(
f"Processing batch {batch_start // self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1) // self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})..."
)
try:
batch_annotations = self._annotate_episodes_batch(
dataset, batch_episodes, video_key, coarse_goal, subtask_labels
)
for ep_idx in batch_episodes:
if ep_idx in batch_annotations and batch_annotations[ep_idx]:
skills = batch_annotations[ep_idx]
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
print(f" Episode {ep_idx}: {len(skills)} skills identified")
else:
failed_episodes[ep_idx] = "Empty or missing skills from batch processing"
print(f"⚠ Episode {ep_idx}: No skills extracted, will retry")
except Exception as e:
print(f"✗ Batch failed: {e}. Falling back to single-episode processing...")
# Fallback: process episodes one by one
for ep_idx in batch_episodes:
try:
skills = self._annotate_episode(
dataset, ep_idx, video_key, coarse_goal, subtask_labels
)
if skills:
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
print(f" Episode {ep_idx}: {len(skills)} skills identified")
else:
failed_episodes[ep_idx] = "Empty skills list from single-episode processing"
print(f"⚠ Episode {ep_idx}: No skills extracted, will retry")
except Exception as ep_error:
failed_episodes[ep_idx] = str(ep_error)
print(f"⚠ Episode {ep_idx} failed: {ep_error}, will retry")
# Retry failed episodes one more time
if failed_episodes:
print(f"\nRetrying {len(failed_episodes)} failed episodes...")
retry_count = 0
for ep_idx, error_msg in list(failed_episodes.items()):
print(f"Retry attempt for episode {ep_idx} (previous error: {error_msg})")
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal, subtask_labels)
if skills:
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
print(f" Episode {ep_idx} (retry): {len(skills)} skills identified")
del failed_episodes[ep_idx]
retry_count += 1
else:
print(f"✗ Episode {ep_idx} (retry): Still no skills extracted")
except Exception as retry_error:
failed_episodes[ep_idx] = str(retry_error)
print(f"✗ Episode {ep_idx} (retry) failed: {retry_error}")
if retry_count > 0:
print(f"Successfully recovered {retry_count} episodes on retry")
if failed_episodes:
print(f"\n⚠ Warning: {len(failed_episodes)} episodes still failed after retry:")
for ep_idx, error_msg in failed_episodes.items():
print(f" Episode {ep_idx}: {error_msg}")
return annotations
def _get_coarse_goal(self, dataset: LeRobotDataset) -> str:
"""Extract or generate the coarse task description."""
# Try to get from existing task metadata
if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0:
# Get the first task description
first_task = dataset.meta.tasks.index[0]
if first_task:
return str(first_task)
return "Perform the demonstrated manipulation task."
def _annotate_episodes_batch(
self,
dataset: LeRobotDataset,
episode_indices: list[int],
video_key: str,
coarse_goal: str,
subtask_labels: list[str] | None = None,
) -> dict[int, list[Skill]]:
"""Annotate multiple episodes with skill labels in a batch."""
# Extract all videos for this batch
extracted_paths = []
timer_paths = []
paths_for_vlm = []
durations = []
valid_episode_indices = []
for ep_idx in episode_indices:
try:
# Get video path and timestamps
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
print(f"Warning: Video not found for episode {ep_idx}")
continue
# Get episode timestamps from metadata
ep = dataset.meta.episodes[ep_idx]
start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
duration = end_ts - start_ts
# Extract episode segment to temporary file
extracted_path = self.video_extractor.extract_episode_video(
video_path, start_ts, end_ts, target_fps=dataset.meta.fps
)
if self.add_timer_overlay:
video_for_vlm = self.video_extractor.add_timer_overlay(extracted_path)
extracted_paths.append(extracted_path)
timer_paths.append(video_for_vlm)
else:
video_for_vlm = extracted_path
extracted_paths.append(extracted_path)
timer_paths.append(None)
paths_for_vlm.append(video_for_vlm)
durations.append(duration)
valid_episode_indices.append(ep_idx)
except Exception as e:
print(f"Warning: Failed to extract video for episode {ep_idx}: {e}")
continue
if not paths_for_vlm:
return {}
try:
# Run VLM skill segmentation in batch
all_skills = self.vlm.segment_skills_batch(paths_for_vlm, durations, coarse_goal, subtask_labels)
# Map results back to episode indices
results = {}
for ep_idx, skills in zip(valid_episode_indices, all_skills, strict=True):
results[ep_idx] = skills
return results
finally:
# Clean up all temporary files (extracted and timer-overlay)
for path in extracted_paths:
if path.exists():
path.unlink()
for path in timer_paths:
if path is not None and path.exists():
path.unlink()
def _annotate_episode(
self,
dataset: LeRobotDataset,
episode_index: int,
video_key: str,
coarse_goal: str,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Annotate a single episode with skill labels."""
# Get video path and timestamps for this episode
video_path = dataset.root / dataset.meta.get_video_file_path(episode_index, video_key)
if not video_path.exists():
raise FileNotFoundError(f"Video not found: {video_path}")
# Get episode timestamps from metadata
ep = dataset.meta.episodes[episode_index]
start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
duration = end_ts - start_ts
# Extract episode segment to temporary file
extracted_path = self.video_extractor.extract_episode_video(
video_path, start_ts, end_ts, target_fps=1
)
if self.add_timer_overlay:
video_for_vlm = self.video_extractor.add_timer_overlay(extracted_path)
else:
video_for_vlm = extracted_path
try:
# Run VLM skill segmentation
skills = self.vlm.segment_skills(video_for_vlm, duration, coarse_goal, subtask_labels)
return skills
finally:
# Clean up temporary files (extracted and optionally timer-overlay)
if extracted_path.exists():
extracted_path.unlink()
if self.add_timer_overlay and video_for_vlm != extracted_path and video_for_vlm.exists():
video_for_vlm.unlink()
# Metadata Writer - Updates per-frame task_index based on skills
def get_skill_for_timestamp(skills: list[Skill], timestamp: float) -> Skill | None:
"""
Find which skill covers a given timestamp.
Args:
skills: List of skills with start/end times
timestamp: Frame timestamp in seconds
Returns:
The Skill that covers this timestamp, or None if not found
"""
for skill in skills:
if skill.start <= timestamp < skill.end:
return skill
# Handle the last frame (end boundary)
if timestamp >= skill.end and skill == skills[-1]:
return skill
return skills[-1] if skills else None # Fallback to last skill
def save_skill_annotations(
dataset: LeRobotDataset,
annotations: dict[int, EpisodeSkills],
output_dir: Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""
Save skill annotations to the dataset by:
1. Creating a subtasks.parquet file with unique subtasks
2. Adding a subtask_index feature to the dataset
3. Saving raw skill annotations as JSON for reference
This function does NOT modify tasks.parquet - it keeps the original tasks intact
and creates a separate subtask hierarchy.
Args:
dataset: The LeRobot dataset to annotate
annotations: Dictionary of episode skills
output_dir: Optional directory to save the modified dataset
repo_id: Optional repository ID for the new dataset
Returns:
New dataset with subtask_index feature added
"""
if not annotations:
print("No annotations to save")
return dataset
# Step 1: Create subtasks DataFrame
print("Creating subtasks DataFrame...")
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe(annotations)
# Step 2: Create subtask_index array for all frames
print("Creating subtask_index array...")
subtask_indices = create_subtask_index_array(dataset, annotations, skill_to_subtask_idx)
# Step 3: Save subtasks.parquet to the original dataset root
save_subtasks(subtasks_df, dataset.root)
# Step 4: Save the raw skill annotations as JSON for reference
skills_path = dataset.root / "meta" / "skills.json"
skills_path.parent.mkdir(parents=True, exist_ok=True)
# Load existing skills data if it exists and is not empty
existing_skills_data = None
if skills_path.exists():
try:
with open(skills_path) as f:
existing_skills_data = json.load(f)
if existing_skills_data and len(existing_skills_data.get("episodes", {})) > 0:
print(
f"Found existing skills.json with {len(existing_skills_data.get('episodes', {}))} episodes, merging..."
)
except (OSError, json.JSONDecodeError):
print("Warning: Could not load existing skills.json, will create new file")
existing_skills_data = None
# Prepare new annotations
new_episodes = {str(ep_idx): ann.to_dict() for ep_idx, ann in annotations.items()}
# Merge with existing data if available
if existing_skills_data:
# Preserve existing episodes that are not being updated
merged_episodes = existing_skills_data.get("episodes", {}).copy()
merged_episodes.update(new_episodes)
# Merge skill_to_subtask_index mappings
merged_skill_to_subtask = existing_skills_data.get("skill_to_subtask_index", {}).copy()
merged_skill_to_subtask.update(skill_to_subtask_idx)
# Use existing coarse_description if available, otherwise use new one
coarse_desc = existing_skills_data.get(
"coarse_description", annotations[next(iter(annotations))].description
)
skills_data = {
"coarse_description": coarse_desc,
"skill_to_subtask_index": merged_skill_to_subtask,
"episodes": merged_episodes,
}
print(
f"Updated {len(new_episodes)} episode(s), total episodes in skills.json: {len(merged_episodes)}"
)
else:
# No existing data, create new
skills_data = {
"coarse_description": annotations[next(iter(annotations))].description,
"skill_to_subtask_index": skill_to_subtask_idx,
"episodes": new_episodes,
}
with open(skills_path, "w") as f:
json.dump(skills_data, f, indent=2)
print(f" Saved skill annotations to {skills_path}")
# Step 5: Add subtask_index feature to dataset using add_features
print("Adding subtask_index feature to dataset...")
# Determine output directory and repo_id
output_dir = dataset.root.parent / f"{dataset.root.name}" if output_dir is None else Path(output_dir)
if repo_id is None:
repo_id = f"{dataset.repo_id}"
# Add feature using dataset_tools
feature_info = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
new_dataset = add_features(
dataset=dataset,
features={
"subtask_index": (subtask_indices, feature_info),
},
output_dir=output_dir,
repo_id=repo_id,
)
# Copy subtasks.parquet to new output directory
import shutil
shutil.copy(dataset.root / "meta" / "subtasks.parquet", output_dir / "meta" / "subtasks.parquet")
shutil.copy(dataset.root / "meta" / "skills.json", output_dir / "meta" / "skills.json")
print(" Successfully added subtask_index feature!")
print(f" New dataset saved to: {new_dataset.root}")
print(f" Total subtasks: {len(subtasks_df)}")
return new_dataset
def load_skill_annotations(dataset_root: Path) -> dict | None:
"""Load existing skill annotations from a dataset."""
skills_path = dataset_root / "meta" / "skills.json"
if skills_path.exists():
with open(skills_path) as f:
return json.load(f)
return None
@@ -1,271 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from abc import ABC, abstractmethod
from pathlib import Path
import torch
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
from lerobot.utils.constants import (
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
format_subtask_labels_section,
)
logger = logging.getLogger(__name__)
DEFAULT_MODEL = "Qwen/Qwen3.5-27B"
def create_skill_segmentation_prompt(
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
duration_seconds: float | None = None,
) -> str:
"""Create the prompt for skill segmentation using the template from constants."""
if duration_seconds is None:
raise ValueError("duration_seconds is required for skill segmentation prompt")
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
subtask_labels_section = format_subtask_labels_section(subtask_labels) if subtask_labels else ""
video_duration_mm_ss = f"{int(duration_seconds // 60):02d}:{int(duration_seconds % 60):02d}"
return SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
goal_context=goal_context,
subtask_labels_section=subtask_labels_section,
video_duration_seconds=duration_seconds,
video_duration_mm_ss=video_duration_mm_ss,
)
class BaseVLM(ABC):
"""
Abstract base class for Vision-Language Models used in skill segmentation.
To add a new VLM family:
1. Subclass BaseVLM
2. Implement __init__, segment_skills, and segment_skills_batch
3. Register it in get_vlm()
"""
@abstractmethod
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
pass
@abstractmethod
def segment_skills(
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
"""Segment a single video into atomic skills."""
pass
@abstractmethod
def segment_skills_batch(
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
"""Segment multiple videos into atomic skills in a single batch."""
pass
def _parse_skills_response(self, response: str) -> list[Skill]:
"""Parse JSON skill list from VLM response text."""
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
skills_data = data.get("skills", data)
if isinstance(skills_data, list):
return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError:
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
try:
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError as e:
raise ValueError(f"Could not parse JSON from VLM response: {response[:200]}...") from e
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
class QwenVL(BaseVLM):
"""Qwen VL model for skill segmentation (default: Qwen3.5 series).
Uses qwen-vl-utils for video processing and the HuggingFace transformers
Qwen3VLProcessor pipeline. Requires transformers >= 5.4.0 for correct
video position embeddings.
"""
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForImageTextToText, AutoProcessor
self.device = device
self.model_name = model_name
self.process_vision_info = process_vision_info
logger.info(f"Loading model: {model_name}...")
self.model = AutoModelForImageTextToText.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.processor.tokenizer.padding_side = "left"
logger.info(f"Model loaded on {device}")
def _build_messages(self, video_path: Path, episode_duration: float, prompt: str) -> list[dict]:
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
return [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
{
"role": "user",
"content": [
{"type": "video", "video": str(video_path), "fps": 1.0},
{
"type": "text",
"text": (
f"Video duration: {duration_str} (exactly {episode_duration:.1f} seconds). "
f"Segment into atomic skills. Last skill must end at {episode_duration:.1f}."
),
},
],
},
]
def _prepare_inputs(self, messages: list[dict]) -> dict:
"""Tokenize a single message and return processor inputs on device."""
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
videos, video_metadata = None, None
if video_inputs:
videos = [v[0] for v in video_inputs]
video_metadata = [v[1] for v in video_inputs]
return self.processor(
text=[text],
images=image_inputs,
videos=videos,
videos_kwargs={
"video_metadata": video_metadata,
"do_sample_frames": False,
},
padding=True,
return_tensors="pt",
).to(self.device)
def _decode(self, inputs, generated_ids) -> list[str]:
return self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
def segment_skills(
self,
video_path: Path,
episode_duration: float,
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[Skill]:
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=episode_duration
)
messages = self._build_messages(video_path, episode_duration, prompt)
inputs = self._prepare_inputs(messages)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self._decode(inputs, generated_ids)[0].strip()
return self._parse_skills_response(response)
def segment_skills_batch(
self,
video_paths: list[Path],
episode_durations: list[float],
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
) -> list[list[Skill]]:
all_texts = []
all_video_tuples: list[tuple] = []
for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
messages = self._build_messages(video_path, duration, prompt)
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
_image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
all_texts.append(text)
all_video_tuples.extend(video_inputs or [])
videos, video_metadata = None, None
if all_video_tuples:
videos = [v[0] for v in all_video_tuples]
video_metadata = [v[1] for v in all_video_tuples]
inputs = self.processor(
text=all_texts,
videos=videos,
videos_kwargs={
"video_metadata": video_metadata,
"do_sample_frames": False,
},
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
responses = self._decode(inputs, generated_ids)
all_skills = []
for idx, response in enumerate(responses):
try:
skills = self._parse_skills_response(response.strip())
if not skills:
logger.warning(f"No skills parsed for video {idx}")
all_skills.append(skills)
except Exception as e:
logger.warning(f"Failed to parse response for video {idx}: {e}")
all_skills.append([])
return all_skills
def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM:
"""Create a VLM instance. Defaults to QwenVL which supports the Qwen3.5 series."""
return QwenVL(model_name, device, torch_dtype)
+67 -51
View File
@@ -18,13 +18,14 @@ from pprint import pformat
import torch
from lerobot.configs.default import DatasetConfig, MultiDatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
@@ -68,66 +69,81 @@ def resolve_delta_timestamps(
return delta_timestamps
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
Args:
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
Raises:
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset:
"""Create a single or multi-dataset depending on the config type.
Returns:
LeRobotDataset | MultiLeRobotDataset
LeRobotDataset | NewMultiLeRobotDataset
"""
if isinstance(cfg.dataset, MultiDatasetConfig):
return _make_multi_dataset(cfg)
return _make_single_dataset(cfg)
def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset:
ds_cfg: DatasetConfig = cfg.dataset # type: ignore[assignment]
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None
)
ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
if not cfg.dataset.streaming:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
tolerance_s=cfg.tolerance_s,
)
else:
dataset = StreamingLeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
if not ds_cfg.streaming:
dataset = LeRobotDataset(
ds_cfg.repo_id,
root=ds_cfg.root,
episodes=ds_cfg.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
revision=ds_cfg.revision,
video_backend=ds_cfg.video_backend,
tolerance_s=cfg.tolerance_s,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
else:
dataset = StreamingLeRobotDataset(
ds_cfg.repo_id,
root=ds_cfg.root,
episodes=ds_cfg.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=ds_cfg.revision,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
)
if cfg.dataset.use_imagenet_stats:
if ds_cfg.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
for stats_type, stats_val in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
return dataset
def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset:
multi_cfg: MultiDatasetConfig = cfg.dataset # type: ignore[assignment]
image_transforms = (
ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None
)
dataset = NewMultiLeRobotDataset(
configs=multi_cfg.datasets,
image_transforms=image_transforms,
tolerance_s=cfg.tolerance_s,
)
logging.info(
"MultiLeRobotDataset created with %d sub-datasets:\n%s",
len(multi_cfg.datasets),
pformat(
{i: c.repo_id for i, c in enumerate(multi_cfg.datasets)},
indent=2,
),
)
if multi_cfg.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats_val in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
return dataset
+364
View File
@@ -0,0 +1,364 @@
"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets.
Supports:
- Per-dataset feature mapping (rename keys to a unified namespace)
- Automatic zero-padding for features missing in some datasets
- Per-dataset transform pipelines
- Weighted sampling via dataset weights
- Aggregated stats across all sub-datasets
- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy
"""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
import torch
import torch.utils.data
from lerobot.configs.default import SubDatasetConfig
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import DatasetTransformPipeline
class MultiDatasetMeta:
"""Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``.
Built by aggregating the metadata of multiple sub-datasets after their
feature keys have been mapped to a unified namespace.
"""
def __init__(
self,
datasets: list[LeRobotDataset],
feature_maps: list[dict[str, str]],
):
self._datasets = datasets
self._feature_maps = feature_maps
self._unified_features = self._build_unified_features()
self._episodes = self._build_episodes()
self._stats = self._build_stats()
# ------------------------------------------------------------------
# Feature union
# ------------------------------------------------------------------
def _build_unified_features(self) -> dict[str, dict]:
"""Build feature dict as the *union* of all mapped feature keys."""
unified: dict[str, dict] = {}
for ds, fmap in zip(self._datasets, self._feature_maps):
for original_key, feat_info in ds.meta.features.items():
mapped_key = fmap.get(original_key, original_key)
if mapped_key not in unified:
unified[mapped_key] = dict(feat_info)
else:
existing_shape = tuple(unified[mapped_key]["shape"])
new_shape = tuple(feat_info["shape"])
if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]:
logging.warning(
"Feature '%s' has shape %s in one dataset but %s in another. "
"The larger shape will be used (padding applied automatically).",
mapped_key,
existing_shape,
new_shape,
)
if np.prod(new_shape) > np.prod(existing_shape):
unified[mapped_key] = dict(feat_info)
return unified
# ------------------------------------------------------------------
# Episode metadata (global flat indexing)
# ------------------------------------------------------------------
def _build_episodes(self) -> dict[str, list]:
"""Concatenate episode boundaries across sub-datasets with frame offsets.
Produces the same column structure as ``load_episodes()`` so that
``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it.
"""
from_indices: list[int] = []
to_indices: list[int] = []
dataset_source: list[int] = []
frame_offset = 0
for ds_idx, ds in enumerate(self._datasets):
eps = ds.meta.episodes
for ep in eps:
from_indices.append(ep["dataset_from_index"] + frame_offset)
to_indices.append(ep["dataset_to_index"] + frame_offset)
dataset_source.append(ds_idx)
frame_offset += ds.num_frames
return {
"dataset_from_index": from_indices,
"dataset_to_index": to_indices,
"dataset_source": dataset_source,
}
# ------------------------------------------------------------------
# Stats aggregation
# ------------------------------------------------------------------
def _build_stats(self) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats across sub-datasets using mapped feature keys."""
mapped_stats_list: list[dict[str, dict]] = []
for ds, fmap in zip(self._datasets, self._feature_maps):
reverse_map = {v: k for k, v in fmap.items()}
mapped: dict[str, dict] = {}
for unified_key in self._unified_features:
original_key = reverse_map.get(unified_key, unified_key)
if original_key in ds.meta.stats:
mapped[unified_key] = ds.meta.stats[original_key]
mapped_stats_list.append(mapped)
return aggregate_stats(mapped_stats_list)
# ------------------------------------------------------------------
# Properties matching LeRobotDatasetMetadata API
# ------------------------------------------------------------------
@property
def features(self) -> dict[str, dict]:
return self._unified_features
@property
def image_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")]
@property
def names(self) -> dict[str, list | dict]:
return {k: f["names"] for k, f in self._unified_features.items()}
@property
def shapes(self) -> dict[str, tuple]:
return {k: tuple(f["shape"]) for k, f in self._unified_features.items()}
@property
def fps(self) -> int:
fps_values = {ds.meta.fps for ds in self._datasets}
if len(fps_values) > 1:
logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values)
return self._datasets[0].meta.fps
@property
def stats(self) -> dict[str, dict[str, np.ndarray]]:
return self._stats
@stats.setter
def stats(self, value: dict):
self._stats = value
@property
def episodes(self) -> dict[str, list]:
return self._episodes
@property
def total_episodes(self) -> int:
return sum(ds.meta.total_episodes for ds in self._datasets)
@property
def total_frames(self) -> int:
return sum(ds.meta.total_frames for ds in self._datasets)
@property
def total_tasks(self) -> int:
return sum(ds.meta.total_tasks for ds in self._datasets)
@property
def info(self) -> dict:
return {
"fps": self.fps,
"features": self._unified_features,
"total_episodes": self.total_episodes,
"total_frames": self.total_frames,
"total_tasks": self.total_tasks,
"codebase_version": "v3.0",
}
class NewMultiLeRobotDataset(torch.utils.data.Dataset):
"""Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding.
Each sub-dataset can have different feature names and shapes. A per-dataset
``feature_map`` renames keys into a shared namespace. Features that a given
sub-dataset does not provide are zero-padded so every ``__getitem__`` returns
the full unified feature set.
"""
def __init__(
self,
configs: list[SubDatasetConfig],
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
):
super().__init__()
self._configs = configs
self.image_transforms = image_transforms
self._datasets: list[LeRobotDataset] = []
self._feature_maps: list[dict[str, str]] = []
self._transform_pipelines: list[DatasetTransformPipeline | None] = []
self._weights: list[float] = []
for cfg in configs:
ds = LeRobotDataset(
repo_id=cfg.repo_id,
root=cfg.root,
episodes=cfg.episodes,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=tolerance_s,
revision=cfg.revision,
video_backend=cfg.video_backend,
)
self._datasets.append(ds)
self._feature_maps.append(cfg.feature_map or {})
self._transform_pipelines.append(
DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None
)
self._weights.append(cfg.weight)
self._meta = MultiDatasetMeta(self._datasets, self._feature_maps)
# Pre-compute cumulative frame counts for fast index mapping.
self._cumulative_frames: list[int] = []
total = 0
for ds in self._datasets:
total += ds.num_frames
self._cumulative_frames.append(total)
# Build reverse maps (unified_key -> original_key) per dataset for padding.
self._reverse_maps: list[dict[str, str]] = []
for fmap in self._feature_maps:
self._reverse_maps.append({v: k for k, v in fmap.items()})
logging.info(
"MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, "
"%d unified features",
len(self._datasets),
self.num_frames,
self.num_episodes,
len(self._meta.features),
)
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
@property
def meta(self) -> MultiDatasetMeta:
return self._meta
@property
def dataset_weights(self) -> list[float]:
return self._weights
@property
def num_frames(self) -> int:
return self._cumulative_frames[-1] if self._cumulative_frames else 0
@property
def num_episodes(self) -> int:
return sum(ds.num_episodes for ds in self._datasets)
@property
def episodes(self) -> list[int] | None:
return None
@property
def fps(self) -> int:
return self._meta.fps
@property
def features(self) -> dict[str, dict]:
return self._meta.features
@property
def camera_keys(self) -> list[str]:
return self._meta.camera_keys
# ------------------------------------------------------------------
# Indexing
# ------------------------------------------------------------------
def _locate(self, idx: int) -> tuple[int, int]:
"""Map a global frame index to (dataset_index, local_index)."""
for ds_idx, cum in enumerate(self._cumulative_frames):
if idx < cum:
local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0)
return ds_idx, local
raise IndexError(f"Index {idx} out of range (total {self.num_frames})")
def __len__(self) -> int:
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
ds_idx, local_idx = self._locate(idx)
item = self._datasets[ds_idx][local_idx]
# 1. Rename keys according to feature_map.
fmap = self._feature_maps[ds_idx]
if fmap:
renamed: dict[str, torch.Tensor] = {}
for key, value in item.items():
renamed[fmap.get(key, key)] = value
item = renamed
# 2. Apply per-dataset transform pipeline.
pipeline = self._transform_pipelines[ds_idx]
if pipeline is not None:
item = pipeline(item)
# 3. Pad missing features with zeros.
reverse_map = self._reverse_maps[ds_idx]
ds_features = self._datasets[ds_idx].meta.features
for unified_key, feat_info in self._meta.features.items():
if unified_key in item:
continue
original_key = reverse_map.get(unified_key, unified_key)
if original_key in ds_features:
continue
shape = tuple(feat_info["shape"])
dtype = feat_info["dtype"]
if dtype in ("video", "image"):
# Camera tensors are (C, H, W) after transforms.
c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1])
item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32)
elif dtype in ("float32", "float64"):
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
elif dtype in ("int32", "int64"):
item[unified_key] = torch.zeros(shape, dtype=torch.int64)
elif dtype == "bool":
item[unified_key] = torch.zeros(shape, dtype=torch.bool)
else:
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
item[f"{unified_key}_is_pad"] = torch.tensor(True)
# 4. Tag which dataset this sample came from.
item["dataset_index"] = torch.tensor(ds_idx)
return item
def __repr__(self) -> str:
repo_ids = [c.repo_id for c in self._configs]
return (
f"NewMultiLeRobotDataset(\n"
f" repo_ids={repo_ids},\n"
f" num_frames={self.num_frames},\n"
f" num_episodes={self.num_episodes},\n"
f" unified_features={list(self._meta.features.keys())},\n"
f" weights={self._weights},\n"
f")"
)
+77
View File
@@ -59,3 +59,80 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler:
"""Sampler that draws frames from multiple datasets according to per-dataset weights.
Each iteration first selects a sub-dataset proportionally to its weight, then
uniformly samples a frame from that sub-dataset's valid index set. Episode
boundary information is respected so that dropped frames are excluded.
Args:
dataset_from_indices: Start index for each episode (global, flat).
dataset_to_indices: End index (exclusive) for each episode (global, flat).
dataset_membership: Which sub-dataset each episode belongs to (integer id).
dataset_weights: Relative sampling weight per sub-dataset.
episode_indices_to_use: If given, only episodes in this set are used.
drop_n_first_frames: Frames to skip at the start of each episode.
drop_n_last_frames: Frames to skip at the end of each episode.
shuffle: Whether to shuffle within each epoch.
num_samples: How many samples per epoch. Defaults to total valid frames.
generator: Optional torch.Generator for reproducibility.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
dataset_membership: list[int],
dataset_weights: list[float],
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
num_samples: int | None = None,
generator: torch.Generator | None = None,
):
n_datasets = max(dataset_membership) + 1 if dataset_membership else 0
self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)]
episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None
for ep_idx, (start, end, ds_id) in enumerate(
zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True)
):
if episodes_to_use is not None and ep_idx not in episodes_to_use:
continue
frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames)
self._per_dataset_indices[ds_id].extend(frame_range)
# Normalise weights (only over datasets that actually have frames).
raw_weights = list(dataset_weights[:n_datasets])
self._weights = torch.zeros(n_datasets)
for i, w in enumerate(raw_weights):
if len(self._per_dataset_indices[i]) > 0:
self._weights[i] = w
total_w = self._weights.sum()
if total_w > 0:
self._weights /= total_w
self._total_frames = sum(len(idx) for idx in self._per_dataset_indices)
self._num_samples = num_samples if num_samples is not None else self._total_frames
self.shuffle = shuffle
self._generator = generator
def __iter__(self) -> Iterator[int]:
if not self.shuffle:
for ds_indices in self._per_dataset_indices:
yield from ds_indices
return
for _ in range(self._num_samples):
ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item())
indices = self._per_dataset_indices[ds_id]
local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item())
yield indices[local_idx]
def __len__(self) -> int:
return self._num_samples
+113
View File
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.nn.functional as F_nn
from torchvision.transforms import v2
from torchvision.transforms.v2 import (
Transform,
@@ -258,3 +260,114 @@ class ImageTransforms(Transform):
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)
# Per-dataset transform pipeline (used by MultiLeRobotDataset)
@dataclass
class DatasetTransformStepConfig:
"""Config for a single per-dataset transform step."""
type: str
kwargs: dict[str, Any] = field(default_factory=dict)
_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {}
def register_dataset_transform(name: str):
"""Decorator to register a DatasetTransformStep by name."""
def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]:
_DATASET_TRANSFORM_REGISTRY[name] = cls
return cls
return decorator
class DatasetTransformStep:
"""Base class for a single per-dataset transform applied to a sample dict."""
def __call__(self, sample: dict) -> dict:
raise NotImplementedError
@register_dataset_transform("pad_action")
class PadAction(DatasetTransformStep):
"""Zero-pad the ``action`` tensor to *target_dim* along the last axis."""
def __init__(self, target_dim: int):
self.target_dim = target_dim
def __call__(self, sample: dict) -> dict:
action = sample.get("action")
if action is None:
return sample
current = action.shape[-1]
if current < self.target_dim:
sample["action"] = F_nn.pad(action, (0, self.target_dim - current))
return sample
@register_dataset_transform("pad_state")
class PadState(DatasetTransformStep):
"""Zero-pad ``observation.state`` to *target_dim* along the last axis."""
def __init__(self, target_dim: int):
self.target_dim = target_dim
def __call__(self, sample: dict) -> dict:
state = sample.get("observation.state")
if state is None:
return sample
current = state.shape[-1]
if current < self.target_dim:
sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current))
return sample
@register_dataset_transform("resize_images")
class ResizeImages(DatasetTransformStep):
"""Resize all image/video camera tensors to (height, width)."""
def __init__(self, height: int, width: int):
self.size = (height, width)
def __call__(self, sample: dict) -> dict:
for key in list(sample.keys()):
if not key.startswith("observation.images."):
continue
img = sample[key]
if not isinstance(img, torch.Tensor) or img.ndim < 3:
continue
sample[key] = F.resize(img, self.size, antialias=True)
return sample
class DatasetTransformPipeline:
"""Sequential pipeline of DatasetTransformStep instances."""
def __init__(self, configs: list[DatasetTransformStepConfig] | None = None):
self.steps: list[DatasetTransformStep] = []
if configs:
for cfg in configs:
self.steps.append(self._build(cfg))
@staticmethod
def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep:
cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type)
if cls is None:
raise ValueError(
f"Unknown dataset transform '{cfg.type}'. "
f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}"
)
return cls(**cfg.kwargs)
def __call__(self, sample: dict) -> dict:
for step in self.steps:
sample = step(sample)
return sample
def __repr__(self) -> str:
return f"DatasetTransformPipeline(steps={self.steps})"
+2 -112
View File
@@ -1,5 +1,4 @@
#!/usr/bin/env python
from __future__ import annotations
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
@@ -22,11 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from lerobot.data_processing.data_annotations.subtask_annotations import EpisodeSkills
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from typing import Any
import datasets
import numpy as np
@@ -1221,111 +1216,6 @@ def find_float_index(target, float_list, threshold=1e-6):
return -1
def create_subtasks_dataframe(
annotations: dict[int, EpisodeSkills],
) -> tuple[pd.DataFrame, dict[str, int]]:
"""
Create a subtasks DataFrame from skill annotations.
Args:
annotations: Dictionary of episode skills
Returns:
Tuple of (subtasks_df, skill_to_subtask_idx mapping)
"""
# Collect all unique skill names
all_skill_names: set[str] = set()
for episode_skills in annotations.values():
for skill in episode_skills.skills:
all_skill_names.add(skill.name)
# Build subtasks DataFrame
subtask_data = []
for i, skill_name in enumerate(sorted(all_skill_names)):
subtask_data.append(
{
"subtask": skill_name,
"subtask_index": i,
}
)
if not subtask_data:
subtasks_df = pd.DataFrame(columns=["subtask", "subtask_index"]).set_index("subtask")
else:
subtasks_df = pd.DataFrame(subtask_data).set_index("subtask")
# Build skill name to subtask_index mapping
skill_to_subtask_idx = {
skill_name: int(subtasks_df.loc[skill_name, "subtask_index"]) for skill_name in all_skill_names
}
return subtasks_df, skill_to_subtask_idx
def save_subtasks(
subtasks_df: pd.DataFrame,
dataset_root: Path,
) -> None:
"""Save subtasks to subtasks.parquet."""
output_path = dataset_root / "meta" / "subtasks.parquet"
output_path.parent.mkdir(parents=True, exist_ok=True)
subtasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
def create_subtask_index_array(
dataset: LeRobotDataset,
annotations: dict[int, EpisodeSkills],
skill_to_subtask_idx: dict[str, int],
) -> np.ndarray:
"""
Create a subtask_index array for each frame based on skill annotations.
Args:
dataset: The LeRobot dataset
annotations: Dictionary of episode skills
skill_to_subtask_idx: Mapping from skill name to subtask_index
Returns:
Array of subtask indices for each frame in the dataset
"""
# Array to store subtask index for each frame
# Initialize with -1 to indicate unannotated frames
full_dataset_length = len(dataset)
subtask_indices = np.full(full_dataset_length, -1, dtype=np.int64)
# Assign subtask_index for each annotated episode
fps = float(dataset.meta.fps)
for ep_idx, episode_skills in annotations.items():
skills = episode_skills.skills
# Get episode frame range
ep = dataset.meta.episodes[ep_idx]
ep_from = int(ep["dataset_from_index"])
ep_to = int(ep["dataset_to_index"])
# Process each frame in the episode (compute timestamp from index to avoid loading video)
for frame_idx in range(ep_from, ep_to):
timestamp = (frame_idx - ep_from) / fps
# Find which skill covers this timestamp (inline to avoid circular import)
skill = None
for s in skills:
if s.start <= timestamp < s.end:
skill = s
break
if timestamp >= s.end and s == skills[-1]:
skill = s
break
if not skill and skills:
skill = skills[-1]
if skill and skill.name in skill_to_subtask_idx:
subtask_idx = skill_to_subtask_idx[skill.name]
subtask_indices[frame_idx] = subtask_idx
return subtask_indices
class LookBackError(Exception):
"""
Exception raised when trying to look back in the history of a Backtrackable object.
@@ -1389,7 +1279,7 @@ class Backtrackable[T]:
self._history = history
self._lookahead = lookahead
def __iter__(self) -> Backtrackable[T]:
def __iter__(self) -> "Backtrackable[T]":
return self
def __next__(self) -> T:
+99
View File
@@ -346,6 +346,105 @@ class LiberoEnv(EnvConfig):
return kwargs
@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
"""Alias config for LIBERO-plus benchmarks.
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
config reuses the existing LIBERO env implementation while making intent explicit
in experiment configs (`env.type=libero_plus`).
"""
task: str = "libero_spatial"
@EnvConfig.register_subclass("robocasa")
@dataclass
class RoboCasaEnv(EnvConfig):
"""RoboCasa kitchen composite-task environments.
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
action space and a structured pixel + state observation dict.
Selected benchmark tasks (3 short + 2 long):
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
Long: PrepareCoffee, RestockPantry
"""
task: str = "PickPlaceCounterToCabinet"
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
fps: int = 20
episode_length: int = 500
image_size: int = 128
split: str = "target" # "pretrain" or "target"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agentview_left": f"{OBS_IMAGES}.agentview_left",
"agentview_right": f"{OBS_IMAGES}.agentview_right",
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
"robot_state": OBS_STATE,
}
)
def __post_init__(self):
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
self.features[cam] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
)
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
@property
def gym_kwargs(self) -> dict:
return {"split": self.split}
@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
Uses BenchmarkEnvBuilder from the robomme package.
"""
task: str = "PickXtimes"
fps: int = 10
episode_length: int = 300
action_space: str = "joint_angle"
dataset_split: str = "test"
task_ids: list[int] | None = None
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"front_rgb": f"{OBS_IMAGES}.front",
"wrist_rgb": f"{OBS_IMAGES}.wrist",
OBS_STATE: OBS_STATE,
}
)
@property
def gym_kwargs(self) -> dict:
return {
"action_space": self.action_space,
"dataset": self.dataset_split,
}
@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):
+50 -3
View File
@@ -20,11 +20,21 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.configs import (
AlohaEnv,
EnvConfig,
HubEnvConfig,
IsaaclabArenaEnv,
LiberoEnv,
LiberoPlusEnv,
PushtEnv,
RoboCasaEnv,
RoboMMEEnv,
)
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -35,6 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
elif env_type == "libero_plus":
return LiberoPlusEnv(**kwargs)
elif env_type == "robocasa":
return RoboCasaEnv(**kwargs)
elif env_type == "robomme":
return RoboMMEEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -70,9 +86,13 @@ def make_env_pre_post_processors(
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
preprocessor_steps.append(RoboCasaProcessorStep())
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
# Parse comma-separated keys (handle None for state-based policies)
@@ -181,6 +201,33 @@ def make_env(
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "robocasa" in cfg.type:
from lerobot.envs.robocasa import create_robocasa_envs
tasks = cfg.tasks if cfg.tasks else [cfg.task]
return create_robocasa_envs(
tasks=tasks,
n_envs=n_envs,
image_size=cfg.image_size,
split=cfg.split,
episode_length=cfg.episode_length,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
elif "robomme" in cfg.type:
from lerobot.envs.robomme import create_robomme_envs
return create_robomme_envs(
task=cfg.task,
n_envs=n_envs,
action_space_type=cfg.action_space,
dataset=cfg.dataset_split,
episode_length=cfg.episode_length,
task_ids=cfg.task_ids,
env_cls=env_cls,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
+8 -2
View File
@@ -26,8 +26,14 @@ import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
try:
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
except ImportError:
# LIBERO-plus may be installed from source with an extra nested package level.
from libero.libero.libero import benchmark, get_libero_path
from libero.libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
+273
View File
@@ -0,0 +1,273 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Action layout (flat 12D, normalized to [-1, 1]):
# [0:3] end_effector_position (delta x, y, z)
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
# [6:7] gripper_close (open=-1, close=+1)
# [7:11] base_motion (x, y, theta, torso_height)
# [11:12] control_mode (arm=-1, base=+1)
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
# Proprioceptive state layout (flat 16D):
# [0:2] gripper_qpos
# [2:5] base_position
# [5:9] base_rotation (quaternion)
# [9:12] end_effector_position_relative
# [12:16] end_effector_rotation_relative (quaternion)
STATE_DIM = 16
# Obs dict keys from RoboCasaGymEnv.get_observation()
_CAM_KEYS = (
"video.robot0_agentview_left",
"video.robot0_agentview_right",
"video.robot0_eye_in_hand",
)
_STATE_KEYS_ORDERED = (
"state.gripper_qpos", # (2,)
"state.base_position", # (3,)
"state.base_rotation", # (4,)
"state.end_effector_position_relative", # (3,)
"state.end_effector_rotation_relative", # (4,)
)
# Mapping from video.* key → short image name used in features_map
CAM_KEY_TO_NAME = {
"video.robot0_agentview_left": "agentview_left",
"video.robot0_agentview_right": "agentview_right",
"video.robot0_eye_in_hand": "eye_in_hand",
}
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
return {
"action.end_effector_position": flat[0:3],
"action.end_effector_rotation": flat[3:6],
"action.gripper_close": flat[6:7],
"action.base_motion": flat[7:11],
"action.control_mode": flat[11:12],
}
class RoboCasaEnv(gym.Env):
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
and a structured observation dict compatible with LeRobot policies.
Observations returned by step/reset:
{
"pixels": {
"agentview_left": (H, W, 3) uint8,
"agentview_right": (H, W, 3) uint8,
"eye_in_hand": (H, W, 3) uint8,
},
"robot_state": (16,) float32,
}
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
"""
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
def __init__(
self,
task: str,
split: str = "target",
image_size: int = 128,
render_mode: str = "rgb_array",
episode_length: int = 500,
**gym_kwargs: Any,
):
super().__init__()
# Lazy import — robocasa is optional
import robocasa.environments # noqa: F401 — registers all gym envs
self.task = task
self.render_mode = render_mode
self.image_size = image_size
self._max_episode_steps = episode_length
self._step_count = 0
self._env = gym.make(
f"robocasa/{task}",
split=split,
camera_widths=image_size,
camera_heights=image_size,
**gym_kwargs,
)
# Flat 12D Box action space
self.action_space = spaces.Box(
low=ACTION_LOW,
high=ACTION_HIGH,
shape=(ACTION_DIM,),
dtype=np.float32,
)
images = {
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
for name in CAM_KEY_TO_NAME.values()
}
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"robot_state": spaces.Box(
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
),
}
)
def _format_obs(self, raw_obs: dict) -> dict:
pixels = {
CAM_KEY_TO_NAME[k]: raw_obs[k]
for k in _CAM_KEYS
if k in raw_obs
}
state_parts = [
np.asarray(raw_obs[k], dtype=np.float32)
for k in _STATE_KEYS_ORDERED
if k in raw_obs
]
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
return {"pixels": pixels, "robot_state": robot_state}
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
super().reset(seed=seed)
self._step_count = 0
raw_obs, info = self._env.reset(seed=seed)
info.setdefault("is_success", False)
info["task"] = self.task
return self._format_obs(raw_obs), info
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
raise ValueError(
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
)
action_dict = _flat_to_action_dict(action)
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
self._step_count += 1
is_success = bool(info.get("success", False))
terminated = terminated or is_success
if self._step_count >= self._max_episode_steps:
truncated = True
info.update({"task": self.task, "is_success": is_success})
obs = self._format_obs(raw_obs)
if terminated or truncated:
info["final_info"] = {"task": self.task, "is_success": is_success}
return obs, reward, terminated, truncated, info
def render(self) -> np.ndarray | None:
if self.render_mode == "rgb_array":
return self._env.render()
return None
def close(self) -> None:
self._env.close()
def _make_env_fns(
*,
task: str,
n_envs: int,
image_size: int,
split: str,
episode_length: int,
gym_kwargs: dict[str, Any],
) -> list[Callable[[], RoboCasaEnv]]:
"""Build n_envs factory callables for a single task."""
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
return RoboCasaEnv(
task=task,
split=split,
image_size=image_size,
episode_length=episode_length,
**gym_kwargs,
)
return [partial(_make, i) for i in range(n_envs)]
def create_robocasa_envs(
tasks: str | Sequence[str],
n_envs: int,
image_size: int = 128,
split: str = "target",
episode_length: int = 500,
gym_kwargs: dict[str, Any] | None = None,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
"""Create vectorized RoboCasa environments.
Args:
tasks: A single task name or list of task names (without "robocasa/" prefix).
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
n_envs: Number of parallel envs per task.
image_size: Square image resolution for all cameras.
split: RoboCasa dataset split — "pretrain" or "target".
episode_length: Max steps per episode before truncation.
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
Returns:
dict[task_name][task_id=0] -> vec_env
"""
if env_cls is None or not callable(env_cls):
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
if not isinstance(n_envs, int) or n_envs <= 0:
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
if isinstance(tasks, str):
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
else:
task_list = [str(t).strip() for t in tasks if str(t).strip()]
if not task_list:
raise ValueError("`tasks` must contain at least one task name.")
gym_kwargs = dict(gym_kwargs or {})
out: dict[str, dict[int, Any]] = defaultdict(dict)
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
for task in task_list:
fns = _make_env_fns(
task=task,
n_envs=n_envs,
image_size=image_size,
split=split,
episode_length=episode_length,
gym_kwargs=gym_kwargs,
)
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
print(f" Built vec env | task={task} | n_envs={n_envs}")
return {suite: dict(task_map) for suite, task_map in out.items()}
+154
View File
@@ -0,0 +1,154 @@
"""RoboMME environment wrapper for LeRobot evaluation.
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
``VectorEnv`` suitable for ``lerobot_eval``.
RoboMME tasks:
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
"""
from __future__ import annotations
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
ROBOMME_TASKS = [
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]
class RoboMMEGymEnv(gym.Env):
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
metadata = {"render_modes": ["rgb_array"]}
def __init__(
self,
task: str = "PickXtimes",
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_idx: int = 0,
max_steps: int = 300,
):
super().__init__()
from robomme.env_record_wrapper import BenchmarkEnvBuilder
self._task = task
self._action_space_type = action_space_type
self._dataset = dataset
self._episode_idx = episode_idx
self._max_steps = max_steps
self._builder = BenchmarkEnvBuilder(
env_id=task,
dataset=dataset,
action_space=action_space_type,
gui_render=False,
max_steps=max_steps,
)
self._env = None
action_dim = 8 if action_space_type == "joint_angle" else 7
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
self.observation_space = spaces.Dict({
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
})
def reset(self, *, seed=None, options=None):
super().reset(seed=seed)
self._env = self._builder.make_env_for_episode(
episode_idx=self._episode_idx, max_steps=self._max_steps,
)
obs, info = self._env.reset()
return self._convert_obs(obs), self._convert_info(info)
def step(self, action):
obs, reward, terminated, truncated, info = self._env.step(action)
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
status = info.get("status", "ongoing")
is_success = status == "success"
conv_info = self._convert_info(info)
conv_info["is_success"] = is_success
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
def _convert_obs(self, obs: dict) -> dict:
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
state = np.concatenate([joint, gripper])
return {
"front_rgb": front_rgb,
"wrist_rgb": wrist_rgb,
"state": state,
}
def _convert_info(self, info: dict) -> dict:
return {
"status": info.get("status", "ongoing"),
"task_goal": info.get("task_goal", ""),
}
def create_robomme_envs(
task: str,
n_envs: int = 1,
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_length: int = 300,
task_ids: list[int] | None = None,
env_cls=None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Create vectorized RoboMME environments for evaluation.
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
"""
if env_cls is None:
env_cls = gym.vector.SyncVectorEnv
if task_ids is None:
task_ids = [0]
suite_name = "robomme"
envs_by_task = {}
for task_id in task_ids:
def _make_one(ep_idx=task_id):
return RoboMMEGymEnv(
task=task,
action_space_type=action_space_type,
dataset=dataset,
episode_idx=ep_idx,
max_steps=episode_length,
)
vec = env_cls(
[_make_one for _ in range(n_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
envs_by_task[task_id] = vec
return {suite_name: envs_by_task}
+38
View File
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
return result
@dataclass
@ProcessorStepRegistry.register(name="robocasa_processor")
class RoboCasaProcessorStep(ObservationProcessorStep):
"""
Processes RoboCasa observations into LeRobot format.
The RoboCasaEnv wrapper returns:
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
- ``observation.robot_state``: (B, 16) float32 proprioception
This step remaps them to:
- ``observation.images.<cam_name>`` (unchanged tensor)
- ``observation.state`` (robot_state renamed)
"""
def _process_observation(self, observation: dict) -> dict:
processed = {}
obs_prefix = OBS_PREFIX # "observation."
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}."):
# Already in the right place; pass through
processed[key] = value
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
# Rename robot_state → observation.state
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
return processed
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def observation(self, observation: dict) -> dict:
return self._process_observation(observation)
@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
@@ -1,160 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Automatic Skill Annotation for LeRobot Datasets.
This script performs automatic subtask/skill labeling for ANY LeRobot dataset using
Vision-Language Models (VLMs). It segments each robot demonstration into short atomic
skills (1-3 seconds each) and creates a new dataset with subtask annotations.
The pipeline:
1. Loads a LeRobot dataset (local or from HuggingFace Hub)
2. For each episode, extracts video frames
3. Uses a VLM to identify skill boundaries and labels
4. Creates a subtasks.parquet file with unique subtasks
5. Adds a subtask_index feature to the dataset
Supported VLMs (modular design): Qwen2-VL, Qwen3-VL, Qwen3.5-VL (see vlm_annotations.py).
Usage:
lerobot-dataset-subtask-annotate --repo_id=user/dataset --video_key=observation.images.base ...
lerobot-dataset-subtask-annotate --data_dir=/path/to/dataset --video_key=observation.images.base ...
"""
from dataclasses import dataclass
from pathlib import Path
import torch
from lerobot.configs import parser
from lerobot.data_processing.data_annotations.subtask_annotations import (
SkillAnnotator,
save_skill_annotations,
)
from lerobot.data_processing.data_annotations.vlm_annotations import get_vlm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@dataclass
class SubtaskAnnotateConfig:
"""Configuration for automatic subtask/skill annotation with VLMs."""
# Data source: provide exactly one of data_dir (local) or repo_id (Hub)
data_dir: str | None = None
repo_id: str | None = None
# Video observation key (e.g. observation.images.base)
video_key: str = "observation.images.base"
# VLM model name (default: Qwen/Qwen2-VL-7B-Instruct)
model: str = "Qwen/Qwen2-VL-7B-Instruct"
device: str = "cuda"
dtype: str = "bfloat16"
batch_size: int = 8
# Episode selection (default: all)
episodes: list[int] | None = None
skip_existing: bool = False
# Output
output_dir: str | None = None
output_repo_id: str | None = None
push_to_hub: bool = False
# Closed vocabulary: comma-separated labels (e.g. "label1,label2,label3")
subtask_labels: str | None = None
# Disable timer overlay on video (by default a timer is drawn for the VLM)
no_timer_overlay: bool = False
@parser.wrap()
def subtask_annotate(cfg: SubtaskAnnotateConfig):
"""
Run automatic skill annotation on a LeRobot dataset using a VLM.
Args:
cfg: SubtaskAnnotateConfig with data source, model, and output options.
"""
if (cfg.data_dir is None) == (cfg.repo_id is None):
raise ValueError("Provide exactly one of --data_dir or --repo_id")
# Parse comma-separated subtask labels into a list (or None)
subtask_labels_list: list[str] | None = None
if cfg.subtask_labels and cfg.subtask_labels.strip():
subtask_labels_list = [s.strip() for s in cfg.subtask_labels.split(",") if s.strip()]
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
torch_dtype = dtype_map[cfg.dtype]
print("Loading dataset...")
if cfg.data_dir:
dataset = LeRobotDataset(repo_id="local/dataset", root=cfg.data_dir, download_videos=False)
else:
dataset = LeRobotDataset(repo_id=cfg.repo_id, download_videos=True)
print(f" Loaded dataset with {dataset.meta.total_episodes} episodes")
if cfg.video_key not in dataset.meta.video_keys:
available = ", ".join(dataset.meta.video_keys)
raise ValueError(f"Video key '{cfg.video_key}' not found. Available: {available}")
print(f"Initializing VLM: {cfg.model}...")
vlm = get_vlm(cfg.model, cfg.device, torch_dtype)
add_timer_overlay = not cfg.no_timer_overlay
annotator = SkillAnnotator(
vlm=vlm,
batch_size=cfg.batch_size,
add_timer_overlay=add_timer_overlay,
)
print(f"Processing with batch size: {cfg.batch_size}")
annotations = annotator.annotate_dataset(
dataset=dataset,
video_key=cfg.video_key,
episodes=cfg.episodes,
skip_existing=cfg.skip_existing,
subtask_labels=subtask_labels_list,
)
output_dir = Path(cfg.output_dir) if cfg.output_dir else None
output_repo_id = cfg.output_repo_id
new_dataset = save_skill_annotations(dataset, annotations, output_dir, output_repo_id)
total_skills = sum(len(ann.skills) for ann in annotations.values())
print("\nAnnotation complete!")
print(f"Episodes annotated: {len(annotations)}")
print(f"Total subtasks identified: {total_skills}")
print(f"Dataset with subtask_index saved to: {new_dataset.root}")
if cfg.push_to_hub:
if cfg.data_dir:
print("Warning: --push_to_hub requires --repo_id, skipping...")
else:
print("Pushing to HuggingFace Hub...")
try:
new_dataset.push_to_hub(branch="subtasks")
print(f" Pushed to {output_repo_id or cfg.repo_id}")
except Exception as e:
print(f"Push failed: {e}")
def main():
"""CLI entry point that parses config and runs subtask annotation."""
subtask_annotate()
if __name__ == "__main__":
main()
+17 -4
View File
@@ -29,7 +29,8 @@ from tqdm import tqdm
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
@@ -343,13 +344,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
if isinstance(dataset, NewMultiLeRobotDataset):
shuffle = False
sampler = WeightedEpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
dataset_membership=dataset.meta.episodes["dataset_source"],
dataset_weights=dataset.dataset_weights,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
elif drop_n_last > 0:
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
else:
@@ -360,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle and not cfg.dataset.streaming,
shuffle=shuffle and not getattr(cfg.dataset, "streaming", False),
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
-89
View File
@@ -89,92 +89,3 @@ LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
def format_subtask_labels_section(subtask_labels: list[str]) -> str:
"""Format a list of subtask labels for the closed-vocabulary section of the prompt."""
return "\n ".join(f'"{label}"' for label in subtask_labels)
SKILL_SEGMENTATION_PROMPT_TEMPLATE = """# Role
You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations.
# Video duration (critical)
The total video length is **{video_duration_seconds} seconds** ({video_duration_mm_ss}). All "start" and "end" values in your JSON must be numeric seconds in the range [0.0, {video_duration_seconds}]. The last skill's "end" must be exactly **{video_duration_seconds}**. Do not stop earlier.
# Task
{goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should:
- Last approximately 1-3 seconds (or longer if the action takes longer)
- Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper")
- Have precise start and end timestamps in seconds (float)
# Requirements
1. **Atomic Actions**: Each skill should be a single, indivisible action
2. **Complete Coverage**: Skills must cover the entire video from 0.0 to {video_duration_seconds} seconds with no gaps
3. **Boundary Consistency**: The end of one skill equals the start of the next
4. **Natural Language**: Use clear, descriptive names for each skill
5. **Timestamps**: Use seconds as floats (e.g. 12.5) for all timestamps; the last "end" must be {video_duration_seconds}. If the video has a visible timer in the corner showing elapsed time in seconds, use it to report accurate start and end times for each skill.
# Subtask Label Set (Closed Vocabulary)
You MUST strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
[
{subtask_labels_section}
]
The video shows one successful execution of all subtasks in a logical order.
# Ground-Truth Semantics (Very Important)
Use **visual state changes** to define when a subtask starts and ends. Do NOT assume equal durations for the subtasks.
- A subtask **starts** at the first frame where the robot's motion clearly initiates that subtask.
- A subtask **ends** at the first frame where that specific action is visually completed and the manipulated object reaches a temporary, stable configuration.
If there are short pauses or micro-motions that don't clearly correspond to a new subtask, they belong to the **current** subtask.
# Hard Constraints & Logic
1. **Continuous Coverage (No Gaps):**
- The entire video from 0.0 to {video_duration_seconds} seconds must be covered by subtasks.
- There can be no gaps between subtasks.
- If there is any idle or ambiguous time between clear actions, extend the *preceding* subtask to cover it.
2. **Boundary Consistency:**
- The `"end"` timestamp of one subtask must be exactly equal to the `"start"` timestamp of the next subtask.
- Boundaries must coincide with a real visual state transition, not just a convenient time split.
3. **Chronological Order, One Occurrence Each:**
- This is a single successful demonstration.
- Each subtask from the vocabulary appears **exactly once**, in the correct logical order.
- **Durations may be very different** between subtasks. Never assume they are similar lengths. Base all boundaries only on the video.
4. **Reject Uniform Segmentation (Important):**
- Do NOT simply divide the video into equal or nearly equal time chunks.
- If your boundaries would result in subtasks with similar durations (e.g. all around 5 seconds), treat this as evidence that your segmentation is wrong and refine the boundaries.
- Only use nearly equal durations if the video truly shows each subtask taking the same amount of time (this is very rare).
5. **Timestamps (critical):**
- Use numeric seconds (float) in the JSON, e.g. 0.0, 5.2, 12.8.
- The first subtask always starts at 0.0.
- The last subtask must end at exactly {video_duration_seconds} (the full video length).
- **Time is displayed inside the video**: a visible timer in one corner shows the elapsed time in seconds (from 0.0 to the end). Use this on-screen timer to set accurate start and end times for each skill.
Format this as a bullet list.
# Output Format
output ONLY valid JSON with this exact structure. The last skill's "end" MUST be exactly {video_duration_seconds}. Use the timestamps you read from the visible timer in the video:
```json
{{
"skills": [
{{"name": "first skill", "start": 0.0, "end": 5.0}},
{{"name": "second skill", "start": 5.0, "end": 12.0}},
{{"name": "last skill", "start": 12.0, "end": {video_duration_seconds}}}
]
}}
```
The first skill must start at 0.0 and the last skill must end at **{video_duration_seconds}** (the total video duration in seconds).
# Strict Structural Rule
This video contains exactly ALL subtasks given to you.
Each segment must use a unique label from the vocabulary.
No label may be repeated.
"""
+1 -1
View File
@@ -74,7 +74,7 @@ _peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_unitree_sdk_available = is_package_available("unitree-sdk2", "unitree_sdk2py")
_pygame_available = is_package_available("pygame")
-168
View File
@@ -23,18 +23,11 @@ These tests verify that:
- Subtask handling gracefully handles missing data
"""
import numpy as np
import pandas as pd
import pytest
import torch
from lerobot.data_processing.data_annotations.subtask_annotations import EpisodeSkills, Skill
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
create_subtask_index_array,
create_subtasks_dataframe,
save_subtasks,
)
class TestSubtaskDataset:
@@ -195,164 +188,3 @@ class TestSubtaskEdgeCases:
)
else:
subtask_map[idx] = subtask
class TestCreateSubtasksDataframe:
"""Tests for create_subtasks_dataframe in utils."""
def test_empty_annotations(self):
"""Empty annotations produce empty DataFrame and empty mapping."""
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe({})
assert len(subtasks_df) == 0
assert list(subtasks_df.columns) == ["subtask_index"]
assert skill_to_subtask_idx == {}
def test_single_episode_single_skill(self):
"""Single episode with one skill produces one row and correct mapping."""
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Pick",
skills=[Skill("pick", 0.0, 1.0)],
),
}
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe(annotations)
assert len(subtasks_df) == 1
assert subtasks_df.index.tolist() == ["pick"]
assert subtasks_df.loc["pick", "subtask_index"] == 0
assert skill_to_subtask_idx == {"pick": 0}
def test_multiple_episodes_overlapping_skills(self):
"""Multiple episodes with overlapping skill names yield unique sorted skills."""
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Ep0",
skills=[
Skill("place", 0.0, 0.5),
Skill("pick", 0.5, 1.0),
],
),
1: EpisodeSkills(
episode_index=1,
description="Ep1",
skills=[Skill("pick", 0.0, 1.0)],
),
}
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe(annotations)
# Sorted order: pick, place
assert subtasks_df.index.tolist() == ["pick", "place"]
assert int(subtasks_df.loc["pick", "subtask_index"]) == 0
assert int(subtasks_df.loc["place", "subtask_index"]) == 1
assert skill_to_subtask_idx["pick"] == 0
assert skill_to_subtask_idx["place"] == 1
def test_skills_sorted_alphabetically(self):
"""Subtask rows are in alphabetical order by skill name."""
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Ep",
skills=[
Skill("z_final", 0.0, 0.33),
Skill("a_first", 0.33, 0.66),
Skill("m_mid", 0.66, 1.0),
],
),
}
subtasks_df, _ = create_subtasks_dataframe(annotations)
assert subtasks_df.index.tolist() == ["a_first", "m_mid", "z_final"]
assert list(subtasks_df["subtask_index"]) == [0, 1, 2]
class TestSaveSubtasks:
"""Tests for save_subtasks in utils."""
def test_save_subtasks_creates_file(self, tmp_path):
"""save_subtasks writes meta/subtasks.parquet and creates parent dir."""
subtasks_df = pd.DataFrame(
[{"subtask": "pick", "subtask_index": 0}, {"subtask": "place", "subtask_index": 1}]
).set_index("subtask")
save_subtasks(subtasks_df, tmp_path)
out = tmp_path / "meta" / "subtasks.parquet"
assert out.exists()
read_df = pd.read_parquet(out)
pd.testing.assert_frame_equal(read_df.reset_index(), subtasks_df.reset_index())
def test_save_subtasks_content_matches(self, tmp_path):
"""Saved parquet round-trips with same content."""
subtasks_df = pd.DataFrame(
[{"subtask": "a", "subtask_index": 0}, {"subtask": "b", "subtask_index": 1}]
).set_index("subtask")
save_subtasks(subtasks_df, tmp_path)
read_df = pd.read_parquet(tmp_path / "meta" / "subtasks.parquet")
assert read_df.index.tolist() == subtasks_df.index.tolist()
assert list(read_df["subtask_index"]) == list(subtasks_df["subtask_index"])
class TestCreateSubtaskIndexArray:
"""Tests for create_subtask_index_array in utils."""
@pytest.fixture
def dataset_with_episodes(self, tmp_path, empty_lerobot_dataset_factory):
"""Dataset with two episodes (10 frames each) for index-array tests."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "subtask_idx", features=features)
for _ in range(10):
dataset.add_frame({"state": torch.randn(2), "task": "Task A"})
dataset.save_episode()
for _ in range(10):
dataset.add_frame({"state": torch.randn(2), "task": "Task B"})
dataset.save_episode()
dataset.finalize()
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_unannotated_all_minus_one(self, dataset_with_episodes):
"""With no annotations, all frame indices are -1."""
skill_to_subtask_idx = {"pick": 0, "place": 1}
arr = create_subtask_index_array(dataset_with_episodes, {}, skill_to_subtask_idx)
assert len(arr) == len(dataset_with_episodes)
assert arr.dtype == np.int64
assert np.all(arr == -1)
def test_annotated_episode_assigns_by_timestamp(self, dataset_with_episodes):
"""Frames in an annotated episode get subtask index from skill time ranges."""
# Dataset uses DEFAULT_FPS=30. Episode 0: 10 frames -> timestamps 0, 1/30, ..., 9/30 (~0.3s).
# Skills: "pick" [0, 0.2), "place" [0.2, 0.5). At 30 fps: 0.2s = 6 frames, so frames 0-5 = pick, 6-9 = place.
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Pick and place",
skills=[
Skill("pick", 0.0, 0.2), # frames 0-5 at 30 fps
Skill("place", 0.2, 0.5), # frames 6-9 at 30 fps
],
),
}
skill_to_subtask_idx = {"pick": 0, "place": 1}
arr = create_subtask_index_array(dataset_with_episodes, annotations, skill_to_subtask_idx)
assert len(arr) == 20
# Episode 0: from_index=0, to_index=10 at 30 fps
for i in range(6):
assert arr[i] == 0, f"frame {i} should be pick"
for i in range(6, 10):
assert arr[i] == 1, f"frame {i} should be place"
# Episode 1 not annotated
for i in range(10, 20):
assert arr[i] == -1
def test_partial_annotations_leave_others_minus_one(self, dataset_with_episodes):
"""Only annotated episodes get non -1 indices; others stay -1."""
annotations = {
1: EpisodeSkills(
episode_index=1,
description="Place only",
skills=[Skill("place", 0.0, 1.0)],
),
}
skill_to_subtask_idx = {"place": 0}
arr = create_subtask_index_array(dataset_with_episodes, annotations, skill_to_subtask_idx)
for i in range(10):
assert arr[i] == -1
for i in range(10, 20):
assert arr[i] == 0
+176
View File
@@ -0,0 +1,176 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RoboCasa LeRobot integration.
Requires: robocasa installed + kitchen assets downloaded.
Tests are skipped automatically if robocasa is not available.
"""
from __future__ import annotations
import numpy as np
import pytest
# Skip entire module if robocasa is not installed or assets are missing
robocasa = pytest.importorskip("robocasa", reason="robocasa not installed")
from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM, CAM_KEY_TO_NAME, RoboCasaEnv, create_robocasa_envs
# The 5 benchmark tasks (3 short + 2 long)
BENCHMARK_TASKS = [
"PickPlaceCounterToCabinet", # short
"PrepareToast", # short
"CoffeeSetupMug", # short
"PrepareCoffee", # long
"RestockPantry", # long
]
SHORT_TASKS = BENCHMARK_TASKS[:3]
LONG_TASKS = BENCHMARK_TASKS[3:]
IMAGE_SIZE = 64 # small for fast tests
@pytest.fixture(scope="module")
def single_env():
"""Shared env instance for lightweight tests."""
env = RoboCasaEnv(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
yield env
env.close()
class TestRoboCasaEnvSpaces:
def test_action_space_is_flat_box(self, single_env):
import gymnasium as gym
assert isinstance(single_env.action_space, gym.spaces.Box)
assert single_env.action_space.shape == (ACTION_DIM,)
assert single_env.action_space.dtype == np.float32
def test_action_bounds(self, single_env):
assert np.all(single_env.action_space.low == -1.0)
assert np.all(single_env.action_space.high == 1.0)
def test_observation_space_has_pixels_and_state(self, single_env):
import gymnasium as gym
assert isinstance(single_env.observation_space, gym.spaces.Dict)
assert "pixels" in single_env.observation_space.spaces
assert "robot_state" in single_env.observation_space.spaces
def test_observation_space_cameras(self, single_env):
pixels_space = single_env.observation_space["pixels"]
expected_cams = set(CAM_KEY_TO_NAME.values())
assert set(pixels_space.spaces.keys()) == expected_cams
def test_state_dim(self, single_env):
state_space = single_env.observation_space["robot_state"]
assert state_space.shape == (STATE_DIM,)
class TestRoboCasaEnvReset:
def test_reset_returns_obs_and_info(self, single_env):
obs, info = single_env.reset()
assert isinstance(obs, dict)
assert isinstance(info, dict)
def test_reset_obs_has_pixels(self, single_env):
obs, _ = single_env.reset()
assert "pixels" in obs
for cam_name in CAM_KEY_TO_NAME.values():
assert cam_name in obs["pixels"], f"Missing camera: {cam_name}"
def test_reset_obs_image_shape(self, single_env):
obs, _ = single_env.reset()
for cam_name, img in obs["pixels"].items():
assert img.shape == (IMAGE_SIZE, IMAGE_SIZE, 3), f"Bad shape for {cam_name}: {img.shape}"
assert img.dtype == np.uint8
def test_reset_obs_state_shape(self, single_env):
obs, _ = single_env.reset()
assert obs["robot_state"].shape == (STATE_DIM,)
assert obs["robot_state"].dtype == np.float32
def test_reset_info_has_task(self, single_env):
_, info = single_env.reset()
assert "task" in info
assert info["task"] == "PickPlaceCounterToCabinet"
class TestRoboCasaEnvStep:
def test_step_10_random_actions(self, single_env):
single_env.reset()
for _ in range(10):
action = single_env.action_space.sample()
obs, reward, terminated, truncated, info = single_env.step(action)
assert obs["robot_state"].shape == (STATE_DIM,)
assert isinstance(reward, float)
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
def test_step_bad_action_raises(self, single_env):
single_env.reset()
with pytest.raises(ValueError, match="Expected 1-D action"):
single_env.step(np.zeros((2, ACTION_DIM)))
def test_step_info_has_is_success(self, single_env):
single_env.reset()
_, _, _, _, info = single_env.step(single_env.action_space.sample())
assert "is_success" in info
class TestRoboCasaConfig:
def test_robocasa_env_config(self):
from lerobot.envs.configs import RoboCasaEnv as RoboCasaEnvConfig
from lerobot.configs.types import FeatureType
cfg = RoboCasaEnvConfig(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
assert cfg.type == "robocasa"
# action feature
assert "action" in cfg.features
assert cfg.features["action"].shape == (ACTION_DIM,)
# camera features
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
assert cam in cfg.features
assert cfg.features[cam].type == FeatureType.VISUAL
assert cfg.features[cam].shape == (IMAGE_SIZE, IMAGE_SIZE, 3)
# state feature
assert "robot_state" in cfg.features
assert cfg.features["robot_state"].shape == (STATE_DIM,)
def test_make_env_config_robocasa(self):
from lerobot.envs.factory import make_env_config
cfg = make_env_config("robocasa", task="PickPlaceCounterToCabinet")
assert cfg.type == "robocasa"
class TestRoboCasaProcessorStep:
def test_processor_remaps_keys(self):
import torch
from lerobot.processor.env_processor import RoboCasaProcessorStep
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
step = RoboCasaProcessorStep()
B = 2
obs = {
f"{OBS_IMAGES}.agentview_left": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
f"{OBS_IMAGES}.agentview_right": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
f"{OBS_IMAGES}.eye_in_hand": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
f"observation.robot_state": torch.zeros(B, STATE_DIM),
}
out = step._process_observation(obs)
assert OBS_STATE in out
assert out[OBS_STATE].dtype == torch.float32
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
assert f"{OBS_IMAGES}.{cam}" in out