Compare commits

..

2 Commits

Author SHA1 Message Date
pepijn db5c26f07d feat(envs): add LIBERO-plus integration for evaluation benchmarks
Add LiberoPlusEnv config (subclass of LiberoEnv), register libero_plus
env type in factory, add import fallbacks for LIBERO-plus package
structure, and add libero_plus optional dependency group in pyproject.toml.

Made-with: Cursor
2026-03-12 04:31:09 +00:00
Pepijn 8904768db4 feat(envs): add RoboCasa composite-task benchmark integration
Integrates 5 selected RoboCasa kitchen tasks (3 short + 2 long) as a
LeRobot benchmark environment, following the same pattern as Libero.

Selected tasks:
  Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
  Long:  PrepareCoffee, RestockPantry

Changes:
- envs/robocasa.py: RoboCasaEnv wrapper with flat 12D Box action space,
  3-camera pixel obs, and 16D proprioceptive state
- envs/configs.py: RoboCasaEnv config with features_map
- envs/factory.py: wire robocasa into make_env + make_env_pre_post_processors
- processor/env_processor.py: RoboCasaProcessorStep for obs key remapping
- tests/test_robocasa_env.py: full test suite (auto-skips if assets missing)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-09 17:08:32 +01:00
155 changed files with 3753 additions and 3176 deletions
+4 -4
View File
@@ -2,7 +2,7 @@
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) and our [AI policy](https://github.com/huggingface/lerobot/blob/main/AI_POLICY.md).
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
## Ways to Contribute
@@ -32,7 +32,7 @@ git remote add upstream https://github.com/huggingface/lerobot.git
### 2. Environment Installation
Please follow our [Installation Guide](https://huggingface.co/docs/lerobot/installation) for the environment setup & installation from source.
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
## Running Tests & Quality Checks
@@ -75,8 +75,8 @@ pytest -sv tests/test_specific_feature.py
Use the templates for required fields and examples.
- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
One member of the LeRobot team will then review your contribution.
+1 -1
View File
@@ -165,7 +165,7 @@ If you are referencing our research or the academic paper, please also cite our
## Contribute
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
<p align="center">
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
-2
View File
@@ -18,8 +18,6 @@
# docker build -f docker/Dockerfile.user -t lerobot-user .
# docker run -it --rm lerobot-user
# With USB physical access : docker run -it --device=/dev/ -v /dev/:/dev/ --rm lerobot-user
# Configure the base image
ARG PYTHON_VERSION=3.12
FROM python:${PYTHON_VERSION}-slim
-2
View File
@@ -19,8 +19,6 @@
title: Multi GPU training
- local: peft_training
title: Training with PEFT (e.g., LoRA)
- local: rename_map
title: Using Rename Map and Empty Cameras
title: "Tutorials"
- sections:
- local: lerobot-dataset-v3
+1 -1
View File
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
+7 -11
View File
@@ -204,26 +204,22 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
Your dataset includes:
**Your Actions (2 features)**:
**Your Actions (2 things)**:
- `linear_velocity`: How much you moved forward/backward
- `angular_velocity`: How much you turned left/right
- How much you moved forward/backward
- How much you turned left/right
**Robot Observations (24 features)**:
**Robot Observations (12 things)**:
- Front camera video
- Rear camera video
- Current speed
- Battery level
- Orientation
- GPS (latitude, longitude, signal strength)
- Which way the robot is facing
- GPS location (latitude, longitude, signal strength)
- Network signal strength
- Vibration level
- Lamp state (on/off)
- Accelerometer (x, y, z)
- Gyroscope (x, y, z)
- Magnetometer (x, y, z)
- Wheel RPMs (4 wheels)
- Lamp status (on/off)
### Where Your Data Goes
+1 -1
View File
@@ -165,7 +165,7 @@ hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
echo $HF_USER
```
-114
View File
@@ -1,114 +0,0 @@
# Rename Map and Empty Cameras
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
## Why observation keys don't always match
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
## Using the rename map
Pass the mapping as a JSON string on the command line. The convention is always:
```
--rename_map='{"source_key": "policy_key", ...}'
```
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
### Training
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--policy.dtype=bfloat16 \
--policy.action_mode=auto \
--steps=20000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
```
### Evaluation
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
```bash
lerobot-eval \
--policy.path=lerobot/pi0fast-libero \
--env.type=libero \
... \
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
```
### Recording
`lerobot-record` also supports rename maps, nested under the dataset config:
```bash
lerobot-record \ # When running inference
--policy.path="<user>/smolVLA_finetuned" \
... \
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
```
## Alternative: edit the policy config directly
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
## Empty cameras: fewer views than the policy expects
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
### How it works
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
```
observation.images.empty_camera_0
observation.images.empty_camera_1
...
```
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
### Example
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
1. Set `--policy.empty_cameras=1`.
2. The config adds a third key: `observation.images.empty_camera_0`.
3. Use the rename map for your two real cameras as usual.
4. The third slot is masked out — no fake images needed in your dataset.
## Quick reference
| Goal | What to do |
| ----------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
+43 -84
View File
@@ -12,59 +12,36 @@ The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train
## Part 1: Getting Started
### Install the Unitree SDK
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation). Tested with `unitree_sdk2py==1.0.1` and `cyclonedds==0.10.2`:
### Install LeRobot on Your Machine
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
pip install -e .
cd ..
```
### Install LeRobot
```bash
conda install ffmpeg -c conda-forge
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
cd unitree_sdk2_python && pip install -e .
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### Test the Installation (Simulation)
The simulation environment has its own dependencies. Check the Simulation environment dependencies: [Unitree G1 Mujoco EnvHub](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main).
```bash
pip install mujoco loguru msgpack msgpack-numpy
```
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
--robot.is_simulation=true \
--teleop.type=unitree_g1 \
--teleop.id=wbc_unitree \
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30, "warmup_s": 5}}' \
--display_data=true \
--robot.controller=GrootLocomotionController
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
--display_data=true
```
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. You can connect a gamepad to your machine before launching in order to control the robot's locomotion in sim. We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) via `--robot.controller`.
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
- Press `9` to release the robot
- Press `7` / `8` to increase / decrease waist height
### Connect to the Physical Robot
### Connect to the Robot
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
@@ -82,11 +59,37 @@ ssh unitree@192.168.123.164
# Password: 123
```
### Share Internet via Ethernet
### Install LeRobot on the G1
The G1 needs internet access to clone repos and install packages. Share your laptop's connection over Ethernet:
From the robot:
**On your laptop:**
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
```
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
---
## Part 2: Enable WiFi on the Robot
Wi-Fi connectivity is blocked by default on the G1. To activate:
```bash
sudo rfkill unblock all
sudo ip link set wlan0 up
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
**On your laptop** (share internet via Ethernet):
```bash
sudo sysctl -w net.ipv4.ip_forward=1
@@ -97,7 +100,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
```
**On the G1:**
**On the G1** (set default route through your laptop):
```bash
sudo ip route del default 2>/dev/null || true
@@ -108,45 +111,6 @@ echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
ping -c 3 8.8.8.8
```
### Install the Unitree SDK on the G1
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation):
```bash
conda create -y -n lerobot python=3.12
conda activate lerobot
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python
python -m pip install -e .
cd ..
```
### Install LeRobot on the G1
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
python -m pip install -e '.[unitree_g1]'
```
<Tip>
For now, pinocchio must be installed from conda-forge (not pip) to include the
CasADi bindings needed for arm IK.
</Tip>
### (Optional) Enable WiFi on the Robot
For wireless SSH access, you can enable WiFi on the G1 (it's blocked by default):
```bash
sudo rfkill unblock all
sudo ip link set wlan0 up
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
**Connect to a WiFi network:**
```bash
@@ -161,7 +125,7 @@ sudo nmcli connection up "YourNetwork"
ip a show wlan0
```
You can then SSH over WiFi instead of Ethernet:
You can now SSH over WiFi:
```bash
ssh unitree@<ROBOT_WIFI_IP>
@@ -170,23 +134,18 @@ ssh unitree@<ROBOT_WIFI_IP>
---
## Part 2: Teleoperation & Locomotion
## Part 3: Teleoperation & Locomotion
### Run the Robot Server
On the robot (from `~/lerobot`):
On the robot:
```bash
cd ~/lerobot
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
```
### Run the Locomotion Policy
You can run the teleoperation client from your laptop over Ethernet, over WiFi (experimental), or directly on the robot itself. Mind potential latency introduced by your network.
**From your laptop:**
```bash
lerobot-teleoperate \
--robot.type=unitree_g1 \
@@ -199,13 +158,13 @@ lerobot-teleoperate \
--robot.controller=HolosomaLocomotionController
```
We support both [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) and [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) via `--robot.controller`.
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
---
## Part 3: Loco-Manipulation with the Homunculus Exoskeleton
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Check it out [here](https://github.com/nepyope/hmc_exo).
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
### Calibrate
@@ -246,7 +205,7 @@ Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/dat
---
## Part 4: Training & Inference
## Part 5: Training & Inference
### Train
+1 -2
View File
@@ -32,8 +32,7 @@ import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
def main():
+1 -1
View File
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import make_default_processors
+1 -1
View File
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
+3 -2
View File
@@ -16,13 +16,15 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import (
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
@@ -38,7 +40,6 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -3
View File
@@ -15,11 +15,11 @@
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
@@ -38,7 +38,6 @@ from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+1 -2
View File
@@ -18,7 +18,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
transition_to_robot_action,
@@ -27,7 +27,6 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
+1 -2
View File
@@ -16,7 +16,7 @@
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
transition_to_robot_action,
@@ -31,7 +31,6 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+1 -2
View File
@@ -22,8 +22,7 @@ from pathlib import Path
import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
DROID_SHARDS = 2048
+2 -2
View File
@@ -26,7 +26,7 @@ from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_droid import DROID_SHARDS
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.utils import init_logging
@@ -155,7 +155,7 @@ class UploadDataset(PipelineStep):
from datasets.utils.tqdm import disable_progress_bars
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.utils.utils import init_logging
init_logging()
+1 -2
View File
@@ -113,9 +113,8 @@ from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.factory import resolve_delta_timestamps
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
+1 -1
View File
@@ -82,7 +82,7 @@ from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
+3 -2
View File
@@ -16,13 +16,15 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import (
RobotAction,
RobotObservation,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
@@ -38,7 +40,6 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -3
View File
@@ -16,11 +16,11 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
@@ -35,7 +35,6 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+1 -2
View File
@@ -19,7 +19,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
transition_to_robot_action,
@@ -28,7 +28,6 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
+1 -2
View File
@@ -17,7 +17,7 @@
import time
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
from lerobot.processor.converters import (
robot_action_observation_to_transition,
robot_action_to_transition,
@@ -30,7 +30,6 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+2 -3
View File
@@ -19,9 +19,8 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
+2 -2
View File
@@ -20,9 +20,9 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
@@ -5,9 +5,8 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
+1 -1
View File
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
@@ -5,9 +5,8 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
+1 -1
View File
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
+1 -1
View File
@@ -6,8 +6,8 @@ from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -1,7 +1,7 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
+9 -3
View File
@@ -76,7 +76,7 @@ dependencies = [
"torchvision>=0.21.0,<0.26.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.14.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"pynput>=1.7.8,<1.9.0",
@@ -119,13 +119,14 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
# "unitree-sdk2==1.0.1",
"unitree-sdk2==1.0.1",
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0",
"onnx>=1.16.0,<2.0.0",
"pin>=3.0.0,<4.0.0",
"meshcat>=0.3.0,<0.4.0",
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
"casadi>=3.6.0,<4.0.0",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
@@ -174,6 +175,11 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero_plus = [
"lerobot[transformers-dep]",
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
"lerobot[scipy-dep]",
]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
+1 -1
View File
@@ -23,7 +23,7 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ( # noqa: F401
+4 -2
View File
@@ -39,13 +39,15 @@ import grpc
import torch
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
)
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import receive_bytes_in_chunks
from lerobot.types import PolicyAction
from .configs import PolicyServerConfig
from .constants import SUPPORTED_POLICIES
-11
View File
@@ -36,16 +36,6 @@ class DatasetConfig:
video_backend: str = field(default_factory=get_safe_default_codec)
streaming: bool = False
def __post_init__(self) -> None:
if self.episodes is not None:
if any(ep < 0 for ep in self.episodes):
raise ValueError(
f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}"
)
if len(self.episodes) != len(set(self.episodes)):
duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1})
raise ValueError(f"Episode indices contain duplicates: {duplicates}")
@dataclass
class WandBConfig:
@@ -57,7 +47,6 @@ class WandBConfig:
notes: str | None = None
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
add_tags: bool = True # If True, save configuration as tags in the WandB run.
@dataclass
+1 -1
View File
@@ -30,8 +30,8 @@ from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
+1 -1
View File
@@ -51,7 +51,7 @@ class TrainPipelineConfig(HubMixin):
# AND for the evaluation environments.
seed: int | None = 1000
# Set to True to use deterministic cuDNN algorithms for reproducibility.
# This disables cudnn.benchmark and may reduce training speed by ~10-20 percent.
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
cudnn_deterministic: bool = False
# Number of workers for the dataloader.
num_workers: int = 4
@@ -746,8 +746,7 @@ def save_annotations_to_dataset(
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
):
"""Save annotations to LeRobot dataset parquet format."""
from lerobot.datasets.io_utils import load_episodes
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
@@ -841,7 +840,7 @@ def generate_auto_sparse_annotations(
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
"""Load annotations from LeRobot dataset parquet files."""
from lerobot.datasets.io_utils import load_episodes
from lerobot.datasets.utils import load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
+8 -10
View File
@@ -24,16 +24,7 @@ import pandas as pd
import tqdm
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -41,7 +32,14 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
get_file_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
@@ -0,0 +1,56 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import packaging.version
V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v3.0 which is not backward compatible with v2.1.
Please, update your dataset to the new format using this command:
```
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
```
If you already have a converted version uploaded to the hub, then this error might be because of
an older version in your local cache. Consider deleting the cached version and retrying.
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
"""
class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
if version.major == 2 and version.minor == 1:
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
else:
raise NotImplementedError(
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
)
super().__init__(message)
class ForwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
+1 -1
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import numpy as np
from lerobot.datasets.io_utils import load_image_as_numpy
from lerobot.datasets.utils import load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
-517
View File
@@ -1,517 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import numpy as np
import packaging.version
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
from lerobot.datasets.io_utils import (
get_file_size_in_mb,
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_json,
write_stats,
write_tasks,
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
INFO_PATH,
check_version_compatibility,
flatten_dict,
get_safe_version,
is_valid_version,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
metadata_buffer_size: int = 10,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.writer = None
self.latest_episode = None
self.metadata_buffer: list[dict] = []
self.metadata_buffer_size = metadata_buffer_size
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
def _flush_metadata_buffer(self) -> None:
"""Write all buffered episode metadata to parquet file."""
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
return
combined_dict = {}
for episode_dict in self.metadata_buffer:
for key, value in episode_dict.items():
if key not in combined_dict:
combined_dict[key] = []
# Extract value and serialize numpy arrays
# because PyArrow's from_pydict function doesn't support numpy arrays
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
first_ep = self.metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
file_idx = first_ep["meta/episodes/file_index"][0]
table = pa.Table.from_pydict(combined_dict)
if not self.writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path.parent.mkdir(parents=True, exist_ok=True)
self.writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self.writer.write_table(table)
self.latest_episode = self.metadata_buffer[-1]
self.metadata_buffer.clear()
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
self._flush_metadata_buffer()
writer = getattr(self, "writer", None)
if writer is not None:
writer.close()
self.writer = None
def __del__(self):
"""
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
"""
self._close_writer()
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
@property
def url_root(self) -> str:
return f"hf://datasets/{self.repo_id}"
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep["data/chunk_index"]
file_idx = ep["data/file_index"]
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
file_idx = ep[f"videos/{vid_key}/file_index"]
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["video_path"]
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
return self.info["robot_type"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def features(self) -> dict[str, dict]:
"""All features contained in the dataset."""
return self.info["features"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
return {key: ft["names"] for key, ft in self.features.items()}
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def chunks_size(self) -> int:
"""Max number of files per chunk."""
return self.info["chunks_size"]
@property
def data_files_size_in_mb(self) -> int:
"""Max size of data file in mega bytes."""
return self.info["data_files_size_in_mb"]
@property
def video_files_size_in_mb(self) -> int:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise return None.
"""
if task in self.tasks.index:
return int(self.tasks.loc[task].task_index)
else:
return None
def save_episode_tasks(self, tasks: list[str]):
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
self.tasks.loc[task] = task_idx
if len(new_tasks) > 0:
# Update on disk
write_tasks(self.tasks, self.root)
def _save_episode_metadata(self, episode_dict: dict) -> None:
"""Buffer episode metadata and write to parquet in batches for efficiency.
This function accumulates episode metadata in a buffer and flushes it when the buffer
reaches the configured size. This reduces I/O overhead by writing multiple episodes
at once instead of one row at a time.
Notes: We both need to update parquet files and HF dataset:
- `pandas` loads parquet file in RAM
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
or loads directly from pyarrow cache.
"""
# Convert to list format for each value
episode_dict = {key: [value] for key, value in episode_dict.items()}
num_frames = episode_dict["length"][0]
if self.latest_episode is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self.episodes is not None and len(self.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
file_idx = self.episodes[-1]["meta/episodes/file_index"]
latest_num_frames = self.episodes[-1]["dataset_to_index"]
episode_dict["dataset_from_index"] = [latest_num_frames]
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
else:
episode_dict["dataset_from_index"] = [0]
episode_dict["dataset_to_index"] = [num_frames]
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
else:
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
file_idx = self.latest_episode["meta/episodes/file_index"][0]
latest_path = (
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
if self.writer is None
else self.writer.where
)
if Path(latest_path).exists():
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
latest_num_frames = self.latest_episode["episode_index"][0]
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
# Size limit is reached, flush buffer and prepare new parquet file
self._flush_metadata_buffer()
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
self._close_writer()
# Update the existing pandas dataframe with new row
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
# Add to buffer
self.metadata_buffer.append(episode_dict)
self.latest_episode = episode_dict
if len(self.metadata_buffer) >= self.metadata_buffer_size:
self._flush_metadata_buffer()
def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
"length": episode_length,
}
episode_dict.update(episode_metadata)
episode_dict.update(flatten_dict({"stats": episode_stats}))
self._save_episode_metadata(episode_dict)
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["total_tasks"] = len(self.tasks)
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.
This allows users to customize storage organization without modifying the constructor.
These settings control how episodes are chunked and how large files can grow before
creating new ones.
Args:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
self.info["chunks_size"] = chunks_size
if data_files_size_in_mb is not None:
if data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
self.info["data_files_size_in_mb"] = data_files_size_in_mb
if video_files_size_in_mb is not None:
if video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
def get_chunk_settings(self) -> dict[str, int]:
"""Get current chunk and file size settings.
Returns:
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
}
def __repr__(self):
feature_keys = list(self.features)
return (
f"{self.__class__.__name__}({{\n"
f" Repository ID: '{self.repo_id}',\n"
f" Total episodes: '{self.total_episodes}',\n"
f" Total frames: '{self.total_frames}',\n"
f" Features: '{feature_keys}',\n"
"})',\n"
)
@classmethod
def create(
cls,
repo_id: str,
fps: int,
features: dict,
robot_type: str | None = None,
root: str | Path | None = None,
use_videos: bool = True,
metadata_buffer_size: int = 10,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
CODEBASE_VERSION,
fps,
features,
use_videos,
robot_type,
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError(
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
"Either remove video features from the features dict, or set 'use_videos=True'."
)
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
obj.latest_episode = None
obj.metadata_buffer = []
obj.metadata_buffer_size = metadata_buffer_size
return obj
+7 -11
View File
@@ -38,22 +38,19 @@ from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.io_utils import (
get_parquet_file_size_in_mb,
load_episodes,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
update_chunk_file_indices,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
@@ -918,8 +915,7 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
"""
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import embed_images
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
hf_features = get_hf_features_from_features(meta.features)
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
+5 -3
View File
@@ -20,9 +20,11 @@ import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
-552
View File
@@ -1,552 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pformat
from typing import Any
import datasets
import numpy as np
from PIL import Image as PILImage
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string
def get_hf_features_from_features(features: dict) -> datasets.Features:
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
Args:
features (dict): A LeRobot-style feature dictionary.
Returns:
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
Raises:
ValueError: If a feature has an unsupported shape.
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
elif len(ft["shape"]) == 2:
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 3:
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 4:
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 5:
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
else:
raise ValueError(f"Corresponding feature is not valid: {ft}")
return datasets.Features(hf_features)
def _validate_feature_names(features: dict[str, dict]) -> None:
"""Validate that feature names do not contain invalid characters.
Args:
features (dict): The LeRobot features dictionary.
Raises:
ValueError: If any feature name contains '/'.
"""
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
This function takes a dictionary describing hardware outputs (like joint states
or camera image shapes) and formats it into the standard LeRobot feature
specification.
Args:
hw_features (dict): Dictionary mapping feature names to their type (float for
joints) or shape (tuple for images).
prefix (str): The prefix to add to the feature keys (e.g., "observation"
or "action").
use_video (bool): If True, image features are marked as "video", otherwise "image".
Returns:
dict: A LeRobot features dictionary.
"""
features = {}
joint_fts = {
key: ftype
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION:
features[prefix] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
if joint_fts and prefix == OBS_STR:
features[f"{prefix}.state"] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
for key, shape in cam_fts.items():
features[f"{prefix}.images.{key}"] = {
"dtype": "video" if use_video else "image",
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features)
return features
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
"""Construct a single data frame from raw values based on dataset features.
A "frame" is a dictionary containing all the data for a single timestep,
formatted as numpy arrays according to the feature specification.
Args:
ds_features (dict): The LeRobot dataset features dictionary.
values (dict): A dictionary of raw values from the hardware/environment.
prefix (str): The prefix to filter features by (e.g., "observation"
or "action").
Returns:
dict: A dictionary representing a single frame of data.
"""
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
continue
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert dataset features to policy features.
This function transforms the dataset's feature specification into a format
that a policy can use, classifying features by type (e.g., visual, state,
action) and ensuring correct shapes (e.g., channel-first for images).
Args:
features (dict): The LeRobot dataset features dictionary.
Returns:
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
Raises:
ValueError: If an image feature does not have a 3D shape.
"""
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
shape = ft["shape"]
if ft["dtype"] in ["image", "video"]:
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):
type = FeatureType.STATE
elif key.startswith(ACTION):
type = FeatureType.ACTION
else:
continue
policy_features[key] = PolicyFeature(
type=type,
shape=shape,
)
return policy_features
def combine_feature_dicts(*dicts: dict) -> dict:
"""Merge LeRobot grouped feature dicts.
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
Args:
*dicts: A variable number of LeRobot feature dictionaries to merge.
Returns:
dict: A single merged feature dictionary.
Raises:
ValueError: If there's a dtype mismatch for a feature being merged.
"""
out: dict = {}
for d in dicts:
for key, value in d.items():
if not isinstance(value, dict):
out[key] = value
continue
dtype = value.get("dtype")
shape = value.get("shape")
is_vector = (
dtype not in ("image", "video", "string")
and isinstance(shape, tuple)
and len(shape) == 1
and "names" in value
)
if is_vector:
# Initialize or retrieve the accumulating dict for this feature key
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
# Ensure consistent data types across merged entries
if "dtype" in target and dtype != target["dtype"]:
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
# Merge feature names: append only new ones to preserve order without duplicates
seen = set(target["names"])
for n in value["names"]:
if n not in seen:
target["names"].append(n)
seen.add(n)
# Recompute the shape to reflect the updated number of features
target["shape"] = (len(target["names"]),)
else:
# For images/videos and non-1D entries: override with the latest definition
out[key] = value
return out
def create_empty_dataset_info(
codebase_version: str,
fps: int,
features: dict,
use_videos: bool,
robot_type: str | None = None,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
Args:
codebase_version (str): The version of the LeRobot codebase.
fps (int): The frames per second of the data.
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
Returns:
dict: A dictionary with the initial dataset metadata.
"""
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
"fps": fps,
"splits": {},
"data_path": DEFAULT_DATA_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
This ensures that adding these delta timestamps to any existing timestamp in
the dataset will result in a value that aligns with the dataset's frame rate.
Args:
delta_timestamps (dict): A dictionary where values are lists of time
deltas in seconds.
fps (int): The frames per second of the dataset.
tolerance_s (float): The allowed tolerance in seconds.
raise_value_error (bool): If True, raises an error on failure.
Returns:
bool: True if all deltas are valid, False otherwise.
Raises:
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
]
if len(outside_tolerance) > 0:
if raise_value_error:
raise ValueError(
f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
)
return False
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
"""Convert delta timestamps in seconds to delta indices in frames.
Args:
delta_timestamps (dict): A dictionary of time deltas in seconds.
fps (int): The frames per second of the dataset.
Returns:
dict: A dictionary of frame delta indices.
"""
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
return delta_indices
def validate_frame(frame: dict, features: dict) -> None:
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
# task is a special required field that's not part of regular features
if "task" not in actual_features:
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
# Remove task from actual_features for regular feature validation
actual_features_for_validation = actual_features - {"task"}
error_message = validate_features_presence(actual_features_for_validation, expected_features)
common_features = actual_features_for_validation & expected_features
for name in common_features:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
"""Check for missing or extra features in a frame.
Args:
actual_features (set[str]): The set of feature names present in the frame.
expected_features (set[str]): The set of feature names expected in the frame.
Returns:
str: An error message string if there's a mismatch, otherwise an empty string.
"""
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"
if missing_features:
error_message += f"Missing features: {missing_features}\n"
if extra_features:
error_message += f"Extra features: {extra_features}\n"
return error_message
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
) -> str:
"""Validate the dtype and shape of a single feature's value.
Args:
name (str): The name of the feature.
feature (dict): The feature specification from the LeRobot features dictionary.
value: The value of the feature to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
Raises:
NotImplementedError: If the feature dtype is not supported for validation.
"""
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
) -> str:
"""Validate a feature that is expected to be a numpy array.
Args:
name (str): The name of the feature.
expected_dtype (str): The expected numpy dtype as a string.
expected_shape (list[int]): The expected shape.
value (np.ndarray): The numpy array to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
actual_shape = value.shape
if actual_dtype != np.dtype(expected_dtype):
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
if actual_shape != expected_shape:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
else:
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
) -> str:
"""Validate a feature that is expected to be an image or video frame.
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W).
value: The image data to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str) -> str:
"""Validate a feature that is expected to be a string.
Args:
name (str): The name of the feature.
value (str): The value to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
"""Validate the episode buffer before it's written to disk.
Ensures the buffer has the required keys, contains at least one frame, and
has features consistent with the dataset's specification.
Args:
episode_buffer (dict): The buffer containing data for a single episode.
total_episodes (int): The current total number of episodes in the dataset.
features (dict): The LeRobot features dictionary for the dataset.
Raises:
ValueError: If the buffer is invalid.
NotImplementedError: If the episode index is manually set and doesn't match.
"""
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
if "task" not in episode_buffer:
raise ValueError("task key not found in episode_buffer")
if episode_buffer["episode_index"] != total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `features`."
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
+4 -6
View File
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import multiprocessing
import queue
import threading
@@ -23,8 +22,6 @@ import numpy as np
import PIL.Image
import torch
logger = logging.getLogger(__name__)
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
@@ -34,7 +31,7 @@ def safe_stop_image_writer(func):
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
logger.warning("Waiting for image writer to terminate...")
print("Waiting for image writer to terminate...")
image_writer.stop()
raise e
@@ -92,7 +89,8 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
PIL.Image.Image object.
Side Effects:
Logs an error message if the image writing process fails for any reason.
Prints an error message to the console if the image writing process
fails for any reason.
"""
try:
if isinstance(image, np.ndarray):
@@ -103,7 +101,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level)
except Exception as e:
logger.error("Error writing image %s: %s", fpath, e)
print(f"Error writing image {fpath}: {e}")
def worker_thread_loop(queue: queue.Queue):
-342
View File
@@ -1,342 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path
from typing import Any
import datasets
import numpy as np
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from datasets.table import embed_table_storage
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.datasets.utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
STATS_PATH,
flatten_dict,
serialize_dict,
unflatten_dict,
)
from lerobot.utils.utils import SuppressProgressBars
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
total_uncompressed_size = 0
for row_group in range(metadata.num_row_groups):
rg_metadata = metadata.row_group(row_group)
for column in range(rg_metadata.num_columns):
col_metadata = rg_metadata.column(column)
total_uncompressed_size += col_metadata.total_uncompressed_size
return total_uncompressed_size / (1024**2)
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
return hf_ds.data.nbytes // (1024**2)
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
with SuppressProgressBars():
# We use .from_parquet() memory-mapped loading for efficiency
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
def get_parquet_num_frames(parquet_path: str | Path) -> int:
metadata = pq.read_metadata(parquet_path)
return metadata.num_rows
def get_file_size_in_mb(file_path: Path) -> float:
"""Get file size on disk in megabytes.
Args:
file_path (Path): Path to the file.
"""
file_size_bytes = file_path.stat().st_size
return file_size_bytes / (1024**2)
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
"""Embed image bytes into the dataset table before saving to Parquet.
This function prepares a Hugging Face dataset for serialization by converting
image objects into an embedded format that can be stored in Arrow/Parquet.
Args:
dataset (datasets.Dataset): The input dataset, possibly containing image features.
Returns:
datasets.Dataset: The dataset with images embedded in the table storage.
"""
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
return dataset
def load_json(fpath: Path) -> Any:
"""Load data from a JSON file.
Args:
fpath (Path): Path to the JSON file.
Returns:
Any: The data loaded from the JSON file.
"""
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
"""Write data to a JSON file.
Creates parent directories if they don't exist.
Args:
data (dict): The dictionary to write.
fpath (Path): The path to the output JSON file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def write_info(info: dict, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
"""Load dataset info metadata from its standard file path.
Also converts shape lists to tuples for consistency.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: The dataset information dictionary.
"""
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def write_stats(stats: dict, local_dir: Path) -> None:
"""Serialize and write dataset statistics to their standard file path.
Args:
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
local_dir (Path): The root directory of the dataset.
"""
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
Args:
stats (dict): The statistics dictionary.
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
"""Load dataset statistics and cast numerical values to numpy arrays.
Returns None if the stats file doesn't exist.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
A dictionary of statistics or None if the file is not found.
"""
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
path = local_dir / DEFAULT_TASKS_PATH
path.parent.mkdir(parents=True, exist_ok=True)
tasks.to_parquet(path)
def load_tasks(local_dir: Path) -> pandas.DataFrame:
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
tasks.index.name = "task"
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
Used primarily during dataset conversion (v2.1 v3.0) and in test fixtures.
Args:
episodes: HuggingFace Dataset containing episode metadata
local_dir: Root directory where the dataset will be stored
"""
episode_size_mb = get_hf_dataset_size_in_mb(episodes)
if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
raise NotImplementedError(
f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
"This function only supports single-file episode metadata. "
)
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
fpath.parent.mkdir(parents=True, exist_ok=True)
episodes.to_parquet(fpath)
def load_episodes(local_dir: Path) -> datasets.Dataset:
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
# Select episode features/columns containing references to episode data and videos
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
# This is to speedup access to these data, instead of having to load episode stats.
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
return episodes
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
"""Load an image from a file into a numpy array.
Args:
fpath (str | Path): Path to the image file.
dtype (np.dtype): The desired data type of the output array. If floating,
pixels are scaled to [0, 1].
channel_first (bool): If True, converts the image to (C, H, W) format.
Otherwise, it remains in (H, W, C) format.
Returns:
np.ndarray: The image as a numpy array.
"""
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if np.issubdtype(dtype, np.floating):
img_array /= 255.0
return img_array
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path)
def item_to_torch(item: dict) -> dict:
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
This function is used to convert an item from a streaming dataset to PyTorch tensors.
Args:
item (dict): Dictionary of items from a dataset.
Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
+679 -23
View File
@@ -23,52 +23,526 @@ from pathlib import Path
import datasets
import numpy as np
import packaging.version
import pandas as pd
import PIL.Image
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.datasets.compute_stats import compute_episode_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
get_delta_indices,
get_hf_features_from_features,
validate_episode_buffer,
validate_frame,
)
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.io_utils import (
embed_images,
get_file_size_in_mb,
hf_transform_to_torch,
load_episodes,
load_nested_dataset,
write_info,
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
INFO_PATH,
_validate_feature_names,
check_delta_timestamps,
check_version_compatibility,
create_empty_dataset_info,
create_lerobot_dataset_card,
embed_images,
flatten_dict,
get_delta_indices,
get_file_size_in_mb,
get_hf_features_from_features,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
write_info,
write_json,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import (
StreamingVideoEncoder,
VideoFrame,
concatenate_video_files,
decode_video_frames,
encode_video_frames,
get_safe_default_codec,
get_video_duration_in_s,
get_video_info,
resolve_vcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
logger = logging.getLogger(__name__)
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
metadata_buffer_size: int = 10,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
self.writer = None
self.latest_episode = None
self.metadata_buffer: list[dict] = []
self.metadata_buffer_size = metadata_buffer_size
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
def _flush_metadata_buffer(self) -> None:
"""Write all buffered episode metadata to parquet file."""
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
return
combined_dict = {}
for episode_dict in self.metadata_buffer:
for key, value in episode_dict.items():
if key not in combined_dict:
combined_dict[key] = []
# Extract value and serialize numpy arrays
# because PyArrow's from_pydict function doesn't support numpy arrays
val = value[0] if isinstance(value, list) else value
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
first_ep = self.metadata_buffer[0]
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
file_idx = first_ep["meta/episodes/file_index"][0]
table = pa.Table.from_pydict(combined_dict)
if not self.writer:
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
path.parent.mkdir(parents=True, exist_ok=True)
self.writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True
)
self.writer.write_table(table)
self.latest_episode = self.metadata_buffer[-1]
self.metadata_buffer.clear()
def _close_writer(self) -> None:
"""Close and cleanup the parquet writer if it exists."""
self._flush_metadata_buffer()
writer = getattr(self, "writer", None)
if writer is not None:
writer.close()
self.writer = None
def __del__(self):
"""
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
"""
self._close_writer()
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
ignore_patterns: list[str] | str | None = None,
) -> None:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
@property
def url_root(self) -> str:
return f"hf://datasets/{self.repo_id}"
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep["data/chunk_index"]
file_idx = ep["data/file_index"]
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
if self.episodes is None:
self.episodes = load_episodes(self.root)
if ep_index >= len(self.episodes):
raise IndexError(
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
)
ep = self.episodes[ep_index]
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
file_idx = ep[f"videos/{vid_key}/file_index"]
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["video_path"]
@property
def robot_type(self) -> str | None:
"""Robot type used in recording this dataset."""
return self.info["robot_type"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def features(self) -> dict[str, dict]:
"""All features contained in the dataset."""
return self.info["features"]
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
return {key: ft["names"] for key, ft in self.features.items()}
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]
@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]
@property
def chunks_size(self) -> int:
"""Max number of files per chunk."""
return self.info["chunks_size"]
@property
def data_files_size_in_mb(self) -> int:
"""Max size of data file in mega bytes."""
return self.info["data_files_size_in_mb"]
@property
def video_files_size_in_mb(self) -> int:
"""Max size of video file in mega bytes."""
return self.info["video_files_size_in_mb"]
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise return None.
"""
if task in self.tasks.index:
return int(self.tasks.loc[task].task_index)
else:
return None
def save_episode_tasks(self, tasks: list[str]):
if len(set(tasks)) != len(tasks):
raise ValueError(f"Tasks are not unique: {tasks}")
if self.tasks is None:
new_tasks = tasks
task_indices = range(len(tasks))
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
self.tasks.loc[task] = task_idx
if len(new_tasks) > 0:
# Update on disk
write_tasks(self.tasks, self.root)
def _save_episode_metadata(self, episode_dict: dict) -> None:
"""Buffer episode metadata and write to parquet in batches for efficiency.
This function accumulates episode metadata in a buffer and flushes it when the buffer
reaches the configured size. This reduces I/O overhead by writing multiple episodes
at once instead of one row at a time.
Notes: We both need to update parquet files and HF dataset:
- `pandas` loads parquet file in RAM
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
or loads directly from pyarrow cache.
"""
# Convert to list format for each value
episode_dict = {key: [value] for key, value in episode_dict.items()}
num_frames = episode_dict["length"][0]
if self.latest_episode is None:
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
if self.episodes is not None and len(self.episodes) > 0:
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
file_idx = self.episodes[-1]["meta/episodes/file_index"]
latest_num_frames = self.episodes[-1]["dataset_to_index"]
episode_dict["dataset_from_index"] = [latest_num_frames]
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
else:
episode_dict["dataset_from_index"] = [0]
episode_dict["dataset_to_index"] = [num_frames]
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
else:
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
file_idx = self.latest_episode["meta/episodes/file_index"][0]
latest_path = (
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
if self.writer is None
else self.writer.where
)
if Path(latest_path).exists():
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
latest_num_frames = self.latest_episode["episode_index"][0]
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
# Size limit is reached, flush buffer and prepare new parquet file
self._flush_metadata_buffer()
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
self._close_writer()
# Update the existing pandas dataframe with new row
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
episode_dict["meta/episodes/file_index"] = [file_idx]
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
# Add to buffer
self.metadata_buffer.append(episode_dict)
self.latest_episode = episode_dict
if len(self.metadata_buffer) >= self.metadata_buffer_size:
self._flush_metadata_buffer()
def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
"length": episode_length,
}
episode_dict.update(episode_metadata)
episode_dict.update(flatten_dict({"stats": episode_stats}))
self._save_episode_metadata(episode_dict)
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["total_tasks"] = len(self.tasks)
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
write_stats(self.stats, self.root)
def update_video_info(self, video_key: str | None = None) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
"""
if video_key is not None and video_key not in self.video_keys:
raise ValueError(f"Video key {video_key} not found in dataset")
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
self,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> None:
"""Update chunk and file size settings after dataset creation.
This allows users to customize storage organization without modifying the constructor.
These settings control how episodes are chunked and how large files can grow before
creating new ones.
Args:
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
"""
if chunks_size is not None:
if chunks_size <= 0:
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
self.info["chunks_size"] = chunks_size
if data_files_size_in_mb is not None:
if data_files_size_in_mb <= 0:
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
self.info["data_files_size_in_mb"] = data_files_size_in_mb
if video_files_size_in_mb is not None:
if video_files_size_in_mb <= 0:
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
self.info["video_files_size_in_mb"] = video_files_size_in_mb
# Update the info file on disk
write_info(self.info, self.root)
def get_chunk_settings(self) -> dict[str, int]:
"""Get current chunk and file size settings.
Returns:
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
"""
return {
"chunks_size": self.chunks_size,
"data_files_size_in_mb": self.data_files_size_in_mb,
"video_files_size_in_mb": self.video_files_size_in_mb,
}
def __repr__(self):
feature_keys = list(self.features)
return (
f"{self.__class__.__name__}({{\n"
f" Repository ID: '{self.repo_id}',\n"
f" Total episodes: '{self.total_episodes}',\n"
f" Total frames: '{self.total_frames}',\n"
f" Features: '{feature_keys}',\n"
"})',\n"
)
@classmethod
def create(
cls,
repo_id: str,
fps: int,
features: dict,
robot_type: str | None = None,
root: str | Path | None = None,
use_videos: bool = True,
metadata_buffer_size: int = 10,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
CODEBASE_VERSION,
fps,
features,
use_videos,
robot_type,
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
obj.latest_episode = None
obj.metadata_buffer = []
obj.metadata_buffer_size = metadata_buffer_size
return obj
def _encode_video_worker(
@@ -122,7 +596,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
the dataset from that address and load it, pending your dataset is compliant with
codebase_version v3.0. If your dataset has been created before this new format, you will be
prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at
lerobot/scripts/convert_dataset_v21_to_v30.py.
lerobot/datasets/v30/convert_dataset_v21_to_v30.py.
2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty
@@ -852,7 +1326,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logger.error(f"Video encoding failed for {video_key}: {exc}")
logging.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self.meta.video_keys:
@@ -891,7 +1365,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if end_episode is None:
end_episode = self.num_episodes
logger.info(
logging.info(
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
)
@@ -901,7 +1375,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logger.info(f"Encoding videos for episode {ep_idx}")
logging.info(f"Encoding videos for episode {ep_idx}")
if (
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
@@ -1131,7 +1605,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
if isinstance(self.image_writer, AsyncImageWriter):
logger.warning(
logging.warning(
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
)
@@ -1209,6 +1683,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
obj.episodes = None
@@ -1242,3 +1717,184 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._streaming_encoder = None
return obj
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
if extra_keys:
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Number of Samples: {self.num_frames},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
-210
View File
@@ -1,210 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Callable
from pathlib import Path
import datasets
import torch
import torch.utils
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import VideoFrame
from lerobot.utils.constants import HF_LEROBOT_HOME
logger = logging.getLogger(__name__)
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
if extra_keys:
logger.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Number of Samples: {self.num_frames},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
+382
View File
@@ -0,0 +1,382 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An online buffer for the online training loop in train.py
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
supports in-place slicing and mutation which is very handy for a dynamic buffer.
"""
import os
from pathlib import Path
from typing import Any
import numpy as np
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def _make_memmap_safe(**kwargs) -> np.memmap:
"""Make a numpy memmap with checks on available disk space first.
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
For information on dtypes:
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
"""
if kwargs["mode"].startswith("w"):
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
stats = os.statvfs(Path(kwargs["filename"]).parent)
available_space = stats.f_bavail * stats.f_frsize # bytes
if required_space >= available_space * 0.8:
raise RuntimeError(
f"You're about to take up {required_space} of {available_space} bytes available."
)
return np.memmap(**kwargs)
class OnlineBuffer(torch.utils.data.Dataset):
"""FIFO data buffer for the online training loop in train.py.
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
loop in the same way that a LeRobotDataset would be used.
The underlying data structure will have data inserted in a circular fashion. Always insert after the
last index, and when you reach the end, wrap around to the start.
The data is stored in a numpy memmap.
"""
NEXT_INDEX_KEY = "_next_index"
OCCUPANCY_MASK_KEY = "_occupancy_mask"
INDEX_KEY = "index"
FRAME_INDEX_KEY = "frame_index"
EPISODE_INDEX_KEY = "episode_index"
TIMESTAMP_KEY = "timestamp"
IS_PAD_POSTFIX = "_is_pad"
def __init__(
self,
write_dir: str | Path,
data_spec: dict[str, Any] | None,
buffer_capacity: int | None,
fps: float | None = None,
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
):
"""
The online buffer can be provided from scratch or you can load an existing online buffer by passing
a `write_dir` associated with an existing buffer.
Args:
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
Note that if the files already exist, they are opened in read-write mode (used for training
resumption.)
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
but note that "index", "frame_index" and "episode_index" are already accounted for by this
class, so you don't need to include them.
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
system's available disk space when choosing this.
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
converted to dict[str, np.ndarray] for optimization purposes.
"""
self.set_delta_timestamps(delta_timestamps)
self._fps = fps
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
# the requested frames. It is only used when `delta_timestamps` is provided.
# minus 1e-4 to account for possible numerical error
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
self._buffer_capacity = buffer_capacity
data_spec = self._make_data_spec(data_spec, buffer_capacity)
Path(write_dir).mkdir(parents=True, exist_ok=True)
self._data = {}
for k, v in data_spec.items():
self._data[k] = _make_memmap_safe(
filename=Path(write_dir) / k,
dtype=v["dtype"] if v is not None else None,
mode="r+" if (Path(write_dir) / k).exists() else "w+",
shape=tuple(v["shape"]) if v is not None else None,
)
@property
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
return self._delta_timestamps
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
"""Set delta_timestamps converting the values to numpy arrays.
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
need to be converted into numpy arrays.
"""
if value is not None:
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
else:
self._delta_timestamps = None
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
)
preset_keys = {
OnlineBuffer.INDEX_KEY,
OnlineBuffer.FRAME_INDEX_KEY,
OnlineBuffer.EPISODE_INDEX_KEY,
OnlineBuffer.TIMESTAMP_KEY,
}
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
raise ValueError(
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
f"The provided data_spec has {intersection}."
)
complete_data_spec = {
# _next_index will be a pointer to the next index that we should start filling from when we add
# more data.
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
"""Add new data to the buffer, which could potentially mean shifting old data out.
The new data should contain all the frames (in order) of any number of episodes. The indices should
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
will be done in place!
"""
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
raise ValueError(f"Missing data keys: {missing_keys}")
new_data_length = len(data[self.data_keys[0]])
if not all(len(data[k]) == new_data_length for k in self.data_keys):
raise ValueError("All data items should have the same length")
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
# Sanity check to make sure that the new data indices start from 0.
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
for k in self.data_keys:
if n_surplus == 0:
slc = slice(next_index, next_index + new_data_length)
self._data[k][slc] = data[k]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
else:
self._data[k][next_index:] = data[k][:-n_surplus]
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
self._data[k][:n_surplus] = data[k][-n_surplus:]
if n_surplus == 0:
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
else:
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
@property
def data_keys(self) -> list[str]:
keys = set(self._data)
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
return sorted(keys)
@property
def fps(self) -> float | None:
return self._fps
@property
def num_episodes(self) -> int:
return len(
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
)
@property
def num_frames(self) -> int:
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
def __len__(self):
return self.num_frames
def _item_to_tensors(self, item: dict) -> dict:
item_ = {}
for k, v in item.items():
if isinstance(v, torch.Tensor):
item_[k] = v
elif isinstance(v, np.ndarray):
item_[k] = torch.from_numpy(v)
else:
item_[k] = torch.tensor(v)
return item_
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self) or idx < -len(self):
raise IndexError
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
if self.delta_timestamps is None:
return self._item_to_tensors(item)
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
episode_data_indices = np.where(
np.bitwise_and(
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
# Get timestamps used as query to retrieve data of previous/future frames.
query_ts = current_ts + self.delta_timestamps[data_key]
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
# the episode.
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
argmin_ = np.argmin(dist, axis=1)
min_ = dist[np.arange(dist.shape[0]), argmin_]
is_pad = min_ > self.tolerance_s
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
)
# Load frames for this data key.
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
return self._item_to_tensors(item)
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
def compute_sampler_weights(
offline_dataset: LeRobotDataset,
offline_drop_n_last_frames: int = 0,
online_dataset: OnlineBuffer | None = None,
online_sampling_ratio: float | None = None,
online_drop_n_last_frames: int = 0,
) -> torch.Tensor:
"""Compute the sampling weights for the online training dataloader in train.py.
Args:
offline_dataset: The LeRobotDataset used for offline pre-training.
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
online_dataset: The OnlineBuffer used in online training.
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
online dataset is provided, this value must also be provided.
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
dataset.
Returns:
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
Notes to maintainers:
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
is the ability to turn shuffling off.
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
weights = []
if len(offline_dataset) > 0:
offline_data_mask_indices = []
for start_index, end_index in zip(
offline_dataset.meta.episodes["dataset_from_index"],
offline_dataset.meta.episodes["dataset_to_index"],
strict=True,
):
offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
weights.append(
torch.full(
size=(len(offline_dataset),),
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
)
* offline_data_mask
)
if online_dataset is not None and len(online_dataset) > 0:
online_data_mask_indices = []
episode_indices = online_dataset.get_data_by_key("episode_index")
for episode_idx in torch.unique(episode_indices):
where_episode = torch.where(episode_indices == episode_idx)
start_index = where_episode[0][0]
end_index = where_episode[0][-1] + 1
online_data_mask_indices.extend(
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
)
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
online_data_mask[torch.tensor(online_data_mask_indices)] = True
weights.append(
torch.full(
size=(len(online_dataset),),
fill_value=online_sampling_ratio / online_data_mask.sum(),
)
* online_data_mask
)
weights = torch.cat(weights)
if weights.sum() == 0:
weights += 1 / len(weights)
else:
weights /= weights.sum()
return weights
+6 -9
View File
@@ -17,9 +17,8 @@ from collections.abc import Sequence
from typing import Any
from lerobot.configs.types import PipelineFeatureType
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline
from lerobot.types import RobotAction, RobotObservation
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
@@ -44,11 +43,11 @@ def create_initial_features(
return features
# Helper to filter state/action keys based on compiled regex patterns.
def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool:
# Helper to filter state/action keys based on regex patterns.
def should_keep(key: str, patterns: tuple[str]) -> bool:
if patterns is None:
return True
return any(pat.search(key) for pat in patterns)
return any(re.search(pat, key) for pat in patterns)
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
@@ -89,8 +88,6 @@ def aggregate_pipeline_dataset_features(
Returns:
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
"""
compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None
all_features = pipeline.transform_features(initial_features)
# Intermediate storage for categorized and filtered features.
@@ -122,7 +119,7 @@ def aggregate_pipeline_dataset_features(
# 2. Apply filtering rules.
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, compiled_patterns):
if not is_image and not should_keep(key, patterns):
continue
# 3. Add the feature to the appropriate group with a clean name.
@@ -0,0 +1,73 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datasets
import torch
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index
-25
View File
@@ -13,13 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Iterator
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
def __init__(
@@ -42,35 +39,13 @@ class EpisodeAwareSampler:
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self.shuffle = shuffle
+7 -163
View File
@@ -13,8 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from collections.abc import Callable, Generator, Iterable, Iterator
from collections.abc import Callable, Generator, Iterator
from pathlib import Path
import datasets
@@ -22,13 +21,16 @@ import numpy as np
import torch
from datasets import load_dataset
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_delta_indices
from lerobot.datasets.io_utils import item_to_torch
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
Backtrackable,
LookAheadError,
LookBackError,
check_version_compatibility,
find_float_index,
get_delta_indices,
is_float_in_list,
item_to_torch,
safe_shard,
)
from lerobot.datasets.video_utils import (
@@ -38,164 +40,6 @@ from lerobot.datasets.video_utils import (
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
class LookBackError(Exception):
"""
Exception raised when trying to look back in the history of a Backtrackable object.
"""
pass
class LookAheadError(Exception):
"""
Exception raised when trying to look ahead in the future of a Backtrackable object.
"""
pass
class Backtrackable[T]:
"""
Wrap any iterator/iterable so you can step back up to `history` items
and look ahead up to `lookahead` items.
This is useful for streaming datasets where you need to access previous and future items
but can't load the entire dataset into memory.
Example:
-------
```python
ds = load_dataset("c4", "en", streaming=True, split="train")
rev = Backtrackable(ds, history=3, lookahead=2)
x0 = next(rev) # forward
x1 = next(rev)
x2 = next(rev)
# Look ahead
x3_peek = rev.peek_ahead(1) # next item without moving cursor
x4_peek = rev.peek_ahead(2) # two items ahead
# Look back
x1_again = rev.peek_back(1) # previous item without moving cursor
x0_again = rev.peek_back(2) # two items back
# Move backward
x1_back = rev.prev() # back one step
next(rev) # returns x2, continues forward from where we were
```
"""
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
if history < 1:
raise ValueError("history must be >= 1")
if lookahead <= 0:
raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable)
self._back_buf: deque[T] = deque(maxlen=history)
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0
self._history = history
self._lookahead = lookahead
def __iter__(self) -> "Backtrackable[T]":
return self
def __next__(self) -> T:
# If we've stepped back, consume from back buffer first
if self._cursor < 0: # -1 means "last item", etc.
self._cursor += 1
return self._back_buf[self._cursor]
# If we have items in the ahead buffer, use them first
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
# Add current item to back buffer and reset cursor
self._back_buf.append(item)
self._cursor = 0
return item
def prev(self) -> T:
"""
Step one item back in history and return it.
Raises IndexError if already at the oldest buffered item.
"""
if len(self._back_buf) + self._cursor <= 1:
raise LookBackError("At start of history")
self._cursor -= 1
return self._back_buf[self._cursor]
def peek_back(self, n: int = 1) -> T:
"""
Look `n` items back (n=1 == previous item) without moving the cursor.
"""
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
raise LookBackError("peek_back distance out of range")
return self._back_buf[self._cursor - (n + 1)]
def peek_ahead(self, n: int = 1) -> T:
"""
Look `n` items ahead (n=1 == next item) without moving the cursor.
Fills the ahead buffer if necessary.
"""
if n < 1:
raise LookAheadError("peek_ahead distance must be 1 or more")
elif n > self._lookahead:
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
# Fill ahead buffer if we don't have enough items
while len(self._ahead_buf) < n:
try:
item = next(self._source)
self._ahead_buf.append(item)
except StopIteration as err:
raise LookAheadError("peek_ahead: not enough items in source") from err
return self._ahead_buf[n - 1]
def history(self) -> list[T]:
"""
Return a copy of the buffered history (most recent last).
The list length `history` argument passed at construction.
"""
if self._cursor == 0:
return list(self._back_buf)
# When cursor<0, slice so the order remains chronological
return list(self._back_buf)[: self._cursor or None]
def can_peek_back(self, steps: int = 1) -> bool:
"""
Check if we can go back `steps` items without raising an IndexError.
"""
return steps <= len(self._back_buf) + self._cursor
def can_peek_ahead(self, steps: int = 1) -> bool:
"""
Check if we can peek ahead `steps` items.
This may involve trying to fill the ahead buffer.
"""
if self._lookahead > 0 and steps > self._lookahead:
return False
# Try to fill ahead buffer to check if we can peek that far
try:
while len(self._ahead_buf) < steps:
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
return False
item = next(self._source)
self._ahead_buf.append(item)
return True
except StopIteration:
return False
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""LeRobotDataset with streaming capabilities.
File diff suppressed because it is too large Load Diff
@@ -28,7 +28,7 @@ quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
Usage:
```bash
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=lerobot/pusht \
```
"""
@@ -45,9 +45,8 @@ from requests import HTTPError
from tqdm import tqdm
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
from lerobot.datasets.io_utils import write_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import write_stats
from lerobot.utils.utils import init_logging
@@ -28,13 +28,13 @@ Usage:
Convert a dataset from the hub:
```bash
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht
```
Convert a local dataset (works in place):
```bash
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht \
--root=/path/to/local/dataset/directory \
--push-to-hub=false
@@ -60,19 +60,7 @@ from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
from lerobot.datasets.io_utils import (
cast_stats_to_numpy,
get_file_size_in_mb,
get_parquet_file_size_in_mb,
get_parquet_num_frames,
load_info,
write_episodes,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -82,8 +70,17 @@ from lerobot.datasets.utils import (
LEGACY_EPISODES_PATH,
LEGACY_EPISODES_STATS_PATH,
LEGACY_TASKS_PATH,
cast_stats_to_numpy,
flatten_dict,
get_file_size_in_mb,
get_parquet_file_size_in_mb,
get_parquet_num_frames,
load_info,
update_chunk_file_indices,
write_episodes,
write_info,
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
+23 -25
View File
@@ -37,8 +37,6 @@ import torchvision
from datasets.features.features import register_feature
from PIL import Image
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
@@ -96,7 +94,7 @@ def detect_available_hw_encoders() -> list[str]:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
pass # nosec B110
return available
@@ -105,14 +103,14 @@ def resolve_vcodec(vcodec: str) -> str:
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logger.info(f"Using video codec: {vcodec}")
logging.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logger.info(f"Auto-selected video codec: {encoder}")
logging.info(f"Auto-selected video codec: {encoder}")
return encoder
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
@@ -120,7 +118,7 @@ def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logger.warning(
logging.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
@@ -210,7 +208,7 @@ def decode_video_frames_torchvision(
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
@@ -246,7 +244,7 @@ def decode_video_frames_torchvision(
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
logging.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
@@ -350,7 +348,7 @@ def decode_video_frames_torchcodec(
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
logger.info(f"Frame loaded at timestamp={pts:.4f}")
logging.info(f"Frame loaded at timestamp={pts:.4f}")
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
@@ -376,7 +374,7 @@ def decode_video_frames_torchcodec(
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logger.info(f"{closest_ts=}")
logging.info(f"{closest_ts=}")
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
@@ -410,14 +408,14 @@ def encode_video_frames(
imgs_dir = Path(imgs_dir)
if video_path.exists() and not overwrite:
logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
return
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logger.warning(
logging.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
@@ -510,7 +508,7 @@ def concatenate_video_files(
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
@@ -695,7 +693,7 @@ class _CameraEncoderThread(threading.Thread):
self.result_queue.put(("ok", None))
except Exception as e:
logger.error(f"Encoder thread error: {e}")
logging.error(f"Encoder thread error: {e}")
if container is not None:
with contextlib.suppress(Exception):
container.close()
@@ -821,7 +819,7 @@ class StreamingVideoEncoder:
count = self._dropped_frames[video_key]
# Log periodically to avoid spam (1st, then every 10th)
if count == 1 or count % 10 == 0:
logger.warning(
logging.warning(
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
)
@@ -843,7 +841,7 @@ class StreamingVideoEncoder:
# Report dropped frames
for video_key, count in self._dropped_frames.items():
if count > 0:
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
# Send sentinel to all queues
for video_key in self._frame_queues:
@@ -853,7 +851,7 @@ class StreamingVideoEncoder:
for video_key in self._threads:
self._threads[video_key].join(timeout=120)
if self._threads[video_key].is_alive():
logger.error(f"Encoder thread for {video_key} did not finish in time")
logging.error(f"Encoder thread for {video_key} did not finish in time")
self._stop_events[video_key].set()
self._threads[video_key].join(timeout=5)
results[video_key] = (self._video_paths[video_key], None)
@@ -865,7 +863,7 @@ class StreamingVideoEncoder:
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
results[video_key] = (self._video_paths[video_key], data)
except queue.Empty:
logger.error(f"No result from encoder thread for {video_key}")
logging.error(f"No result from encoder thread for {video_key}")
results[video_key] = (self._video_paths[video_key], None)
self._cleanup()
@@ -1073,13 +1071,13 @@ class VideoEncodingManager:
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
if exc_type is not None:
logger.info("Exception occurred. Encoding remaining episodes before exit...")
logging.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logger.info("Recording stopped. Encoding remaining episodes...")
logging.info("Recording stopped. Encoding remaining episodes...")
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logger.info(
logging.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
@@ -1096,7 +1094,7 @@ class VideoEncodingManager:
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logger.debug(
logging.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
@@ -1107,8 +1105,8 @@ class VideoEncodingManager:
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
shutil.rmtree(img_dir)
logger.debug("Cleaned up empty images directory")
logging.debug("Cleaned up empty images directory")
else:
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
+59
View File
@@ -346,6 +346,65 @@ class LiberoEnv(EnvConfig):
return kwargs
@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
"""Alias config for LIBERO-plus benchmarks.
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
config reuses the existing LIBERO env implementation while making intent explicit
in experiment configs (`env.type=libero_plus`).
"""
task: str = "libero_spatial"
@EnvConfig.register_subclass("robocasa")
@dataclass
class RoboCasaEnv(EnvConfig):
"""RoboCasa kitchen composite-task environments.
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
action space and a structured pixel + state observation dict.
Selected benchmark tasks (3 short + 2 long):
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
Long: PrepareCoffee, RestockPantry
"""
task: str = "PickPlaceCounterToCabinet"
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
fps: int = 20
episode_length: int = 500
image_size: int = 128
split: str = "target" # "pretrain" or "target"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agentview_left": f"{OBS_IMAGES}.agentview_left",
"agentview_right": f"{OBS_IMAGES}.agentview_right",
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
"robot_state": OBS_STATE,
}
)
def __post_init__(self):
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
self.features[cam] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
)
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
@property
def gym_kwargs(self) -> dict:
return {"split": self.split}
@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):
+34 -3
View File
@@ -20,11 +20,20 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.configs import (
AlohaEnv,
EnvConfig,
HubEnvConfig,
IsaaclabArenaEnv,
LiberoEnv,
LiberoPlusEnv,
PushtEnv,
RoboCasaEnv,
)
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -35,6 +44,10 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
elif env_type == "libero_plus":
return LiberoPlusEnv(**kwargs)
elif env_type == "robocasa":
return RoboCasaEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -70,9 +83,13 @@ def make_env_pre_post_processors(
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
preprocessor_steps.append(RoboCasaProcessorStep())
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
# Parse comma-separated keys (handle None for state-based policies)
@@ -181,6 +198,20 @@ def make_env(
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "robocasa" in cfg.type:
from lerobot.envs.robocasa import create_robocasa_envs
tasks = cfg.tasks if cfg.tasks else [cfg.task]
return create_robocasa_envs(
tasks=tasks,
n_envs=n_envs,
image_size=cfg.image_size,
split=cfg.split,
episode_length=cfg.episode_length,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
+9 -3
View File
@@ -26,10 +26,16 @@ import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from lerobot.types import RobotObservation
try:
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
except ImportError:
# LIBERO-plus may be installed from source with an extra nested package level.
from libero.libero.libero import benchmark, get_libero_path
from libero.libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
+1 -1
View File
@@ -25,7 +25,7 @@ import metaworld.policies as policies
import numpy as np
from gymnasium import spaces
from lerobot.types import RobotObservation
from lerobot.processor import RobotObservation
# ---- Load configuration data from the external JSON file ----
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
+273
View File
@@ -0,0 +1,273 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Action layout (flat 12D, normalized to [-1, 1]):
# [0:3] end_effector_position (delta x, y, z)
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
# [6:7] gripper_close (open=-1, close=+1)
# [7:11] base_motion (x, y, theta, torso_height)
# [11:12] control_mode (arm=-1, base=+1)
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
# Proprioceptive state layout (flat 16D):
# [0:2] gripper_qpos
# [2:5] base_position
# [5:9] base_rotation (quaternion)
# [9:12] end_effector_position_relative
# [12:16] end_effector_rotation_relative (quaternion)
STATE_DIM = 16
# Obs dict keys from RoboCasaGymEnv.get_observation()
_CAM_KEYS = (
"video.robot0_agentview_left",
"video.robot0_agentview_right",
"video.robot0_eye_in_hand",
)
_STATE_KEYS_ORDERED = (
"state.gripper_qpos", # (2,)
"state.base_position", # (3,)
"state.base_rotation", # (4,)
"state.end_effector_position_relative", # (3,)
"state.end_effector_rotation_relative", # (4,)
)
# Mapping from video.* key → short image name used in features_map
CAM_KEY_TO_NAME = {
"video.robot0_agentview_left": "agentview_left",
"video.robot0_agentview_right": "agentview_right",
"video.robot0_eye_in_hand": "eye_in_hand",
}
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
return {
"action.end_effector_position": flat[0:3],
"action.end_effector_rotation": flat[3:6],
"action.gripper_close": flat[6:7],
"action.base_motion": flat[7:11],
"action.control_mode": flat[11:12],
}
class RoboCasaEnv(gym.Env):
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
and a structured observation dict compatible with LeRobot policies.
Observations returned by step/reset:
{
"pixels": {
"agentview_left": (H, W, 3) uint8,
"agentview_right": (H, W, 3) uint8,
"eye_in_hand": (H, W, 3) uint8,
},
"robot_state": (16,) float32,
}
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
"""
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
def __init__(
self,
task: str,
split: str = "target",
image_size: int = 128,
render_mode: str = "rgb_array",
episode_length: int = 500,
**gym_kwargs: Any,
):
super().__init__()
# Lazy import — robocasa is optional
import robocasa.environments # noqa: F401 — registers all gym envs
self.task = task
self.render_mode = render_mode
self.image_size = image_size
self._max_episode_steps = episode_length
self._step_count = 0
self._env = gym.make(
f"robocasa/{task}",
split=split,
camera_widths=image_size,
camera_heights=image_size,
**gym_kwargs,
)
# Flat 12D Box action space
self.action_space = spaces.Box(
low=ACTION_LOW,
high=ACTION_HIGH,
shape=(ACTION_DIM,),
dtype=np.float32,
)
images = {
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
for name in CAM_KEY_TO_NAME.values()
}
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"robot_state": spaces.Box(
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
),
}
)
def _format_obs(self, raw_obs: dict) -> dict:
pixels = {
CAM_KEY_TO_NAME[k]: raw_obs[k]
for k in _CAM_KEYS
if k in raw_obs
}
state_parts = [
np.asarray(raw_obs[k], dtype=np.float32)
for k in _STATE_KEYS_ORDERED
if k in raw_obs
]
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
return {"pixels": pixels, "robot_state": robot_state}
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
super().reset(seed=seed)
self._step_count = 0
raw_obs, info = self._env.reset(seed=seed)
info.setdefault("is_success", False)
info["task"] = self.task
return self._format_obs(raw_obs), info
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
raise ValueError(
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
)
action_dict = _flat_to_action_dict(action)
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
self._step_count += 1
is_success = bool(info.get("success", False))
terminated = terminated or is_success
if self._step_count >= self._max_episode_steps:
truncated = True
info.update({"task": self.task, "is_success": is_success})
obs = self._format_obs(raw_obs)
if terminated or truncated:
info["final_info"] = {"task": self.task, "is_success": is_success}
return obs, reward, terminated, truncated, info
def render(self) -> np.ndarray | None:
if self.render_mode == "rgb_array":
return self._env.render()
return None
def close(self) -> None:
self._env.close()
def _make_env_fns(
*,
task: str,
n_envs: int,
image_size: int,
split: str,
episode_length: int,
gym_kwargs: dict[str, Any],
) -> list[Callable[[], RoboCasaEnv]]:
"""Build n_envs factory callables for a single task."""
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
return RoboCasaEnv(
task=task,
split=split,
image_size=image_size,
episode_length=episode_length,
**gym_kwargs,
)
return [partial(_make, i) for i in range(n_envs)]
def create_robocasa_envs(
tasks: str | Sequence[str],
n_envs: int,
image_size: int = 128,
split: str = "target",
episode_length: int = 500,
gym_kwargs: dict[str, Any] | None = None,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
"""Create vectorized RoboCasa environments.
Args:
tasks: A single task name or list of task names (without "robocasa/" prefix).
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
n_envs: Number of parallel envs per task.
image_size: Square image resolution for all cameras.
split: RoboCasa dataset split "pretrain" or "target".
episode_length: Max steps per episode before truncation.
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
Returns:
dict[task_name][task_id=0] -> vec_env
"""
if env_cls is None or not callable(env_cls):
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
if not isinstance(n_envs, int) or n_envs <= 0:
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
if isinstance(tasks, str):
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
else:
task_list = [str(t).strip() for t in tasks if str(t).strip()]
if not task_list:
raise ValueError("`tasks` must contain at least one task name.")
gym_kwargs = dict(gym_kwargs or {})
out: dict[str, dict[int, Any]] = defaultdict(dict)
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
for task in task_list:
fns = _make_env_fns(
task=task,
n_envs=n_envs,
image_size=image_size,
split=split,
episode_length=episode_length,
gym_kwargs=gym_kwargs,
)
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
print(f" Built vec env | task={task} | n_envs={n_envs}")
return {suite: dict(task_map) for suite, task_map in out.items()}
+1 -1
View File
@@ -29,7 +29,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.types import RobotObservation
from lerobot.processor import RobotObservation
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
+1 -2
View File
@@ -23,8 +23,7 @@ import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.datasets.io_utils import write_json
from lerobot.datasets.utils import flatten_dict, unflatten_dict
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
+1 -1
View File
@@ -23,7 +23,7 @@ import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.datasets.io_utils import write_json
from lerobot.datasets.utils import write_json
from lerobot.utils.constants import SCHEDULER_STATE
from lerobot.utils.io_utils import deserialize_json_into_object
+3 -4
View File
@@ -24,8 +24,8 @@ import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
@@ -43,14 +43,13 @@ from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
transition_to_policy_action,
)
from lerobot.types import PolicyAction
from lerobot.utils.constants import (
ACTION,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
@@ -49,7 +49,7 @@ from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.types import EnvTransition, TransitionKey
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
ACTION,
HF_LEROBOT_HOME,
+1 -1
View File
@@ -36,7 +36,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
@@ -37,7 +37,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
+1 -1
View File
@@ -48,8 +48,8 @@ from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@@ -68,7 +68,7 @@ from lerobot.policies.utils import (
populate_queues,
)
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.device_utils import get_safe_dtype
from lerobot.utils.utils import get_safe_dtype
class ActionSelectKwargs(TypedDict, total=False):
@@ -374,11 +374,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
original_action_dim = self.config.action_feature.shape[0]
losses = losses[:, :, :original_action_dim]
loss_dict["losses_after_forward"] = losses.clone().mean().item()
if actions_is_pad is not None:
+2 -2
View File
@@ -23,8 +23,8 @@ from torch import nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import build_dataset_frame
from lerobot.types import PolicyAction, RobotAction, RobotObservation
from lerobot.datasets.utils import build_dataset_frame
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
+2 -2
View File
@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
self.vqvae_model.optimized_steps += 1
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
self.vqvae_model.discretized.fill_(True)
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
self.vqvae_model.discretized = torch.tensor(True)
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
print("Finished discretizing action data!")
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():
+1 -1
View File
@@ -38,7 +38,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
OBS_PREFIX,
+7 -8
View File
@@ -14,7 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.types import (
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
create_transition,
transition_to_batch,
)
from .core import (
EnvAction,
EnvTransition,
PolicyAction,
@@ -22,13 +28,6 @@ from lerobot.types import (
RobotObservation,
TransitionKey,
)
from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
create_transition,
transition_to_batch,
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
from .factory import (
+1 -1
View File
@@ -25,9 +25,9 @@ from dataclasses import dataclass, field
from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, PolicyAction
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .core import EnvTransition, PolicyAction
from .pipeline import (
ComplementaryDataProcessorStep,
ObservationProcessorStep,
+2 -1
View File
@@ -23,9 +23,10 @@ from typing import Any
import numpy as np
import torch
from lerobot.types import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@singledispatch
def to_tensor(
@@ -17,8 +17,8 @@
from dataclasses import dataclass
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import PolicyAction, RobotAction
from .core import PolicyAction, RobotAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
+2 -2
View File
@@ -25,9 +25,9 @@ from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.utils import get_safe_torch_device
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
+38
View File
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
return result
@dataclass
@ProcessorStepRegistry.register(name="robocasa_processor")
class RoboCasaProcessorStep(ObservationProcessorStep):
"""
Processes RoboCasa observations into LeRobot format.
The RoboCasaEnv wrapper returns:
- ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
- ``observation.robot_state``: (B, 16) float32 proprioception
This step remaps them to:
- ``observation.images.<cam_name>`` (unchanged tensor)
- ``observation.state`` (robot_state renamed)
"""
def _process_observation(self, observation: dict) -> dict:
processed = {}
obs_prefix = OBS_PREFIX # "observation."
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}."):
# Already in the right place; pass through
processed[key] = value
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
# Rename robot_state → observation.state
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
return processed
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def observation(self, observation: dict) -> dict:
return self._process_observation(observation)
@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
+1 -2
View File
@@ -14,14 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.types import RobotAction, RobotObservation
from .converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from .core import RobotAction, RobotObservation
from .pipeline import IdentityProcessorStep, RobotProcessorPipeline
@@ -17,9 +17,9 @@
from dataclasses import dataclass
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvAction, EnvTransition, PolicyAction
from .converters import to_tensor
from .core import EnvAction, EnvTransition, PolicyAction
from .hil_processor import TELEOP_ACTION_KEY
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
@@ -75,7 +75,7 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Converts numpy action to torch tensor if action exists, otherwise passes through."""
from lerobot.types import TransitionKey
from .core import TransitionKey
self._current_transition = transition.copy()
new_transition = self._current_transition
+1 -2
View File
@@ -30,8 +30,7 @@ from lerobot.teleoperators.utils import TeleopEvents
if TYPE_CHECKING:
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import (
ComplementaryDataProcessorStep,
InfoProcessorStep,
+1 -14
View File
@@ -26,10 +26,10 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.constants import ACTION
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
@@ -131,15 +131,6 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -158,7 +149,6 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -208,7 +198,6 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -220,8 +209,6 @@ class _NormalizationMixin:
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
self.stats = {}
+1 -1
View File
@@ -46,10 +46,10 @@ from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
from lerobot.utils.hub import HubMixin
from .converters import batch_to_transition, create_transition, transition_to_batch
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
# Generic type variables for pipeline input and output.
TInput = TypeVar("TInput")
+1 -1
View File
@@ -30,7 +30,6 @@ from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
@@ -41,6 +40,7 @@ from lerobot.utils.constants import (
)
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, RobotObservation, TransitionKey
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
+3 -11
View File
@@ -62,7 +62,7 @@ from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.queue import get_last_item_from_queue
from lerobot.robots import so_follower # noqa: F401
@@ -77,8 +77,6 @@ from lerobot.transport.utils import (
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.types import TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
@@ -88,6 +86,7 @@ from lerobot.utils.transition import (
)
from lerobot.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
)
@@ -259,11 +258,6 @@ def act_with_policy(
policy = policy.eval()
assert isinstance(policy, nn.Module)
preprocessor, postprocessor = make_sac_pre_post_processors(
config=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
@@ -295,9 +289,7 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation)
# action = postprocessor.process_action(action)
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+1 -13
View File
@@ -66,7 +66,6 @@ from lerobot.datasets.factory import make_dataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.wandb_utils import WandBLogger
@@ -87,7 +86,6 @@ from lerobot.utils.constants import (
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
@@ -98,6 +96,7 @@ from lerobot.utils.train_utils import (
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
init_logging,
)
@@ -314,11 +313,6 @@ def add_actor_information_and_train(
assert isinstance(policy, nn.Module)
preprocessor, _ = make_sac_pre_post_processors(
config=cfg.policy,
dataset_stats=cfg.policy.dataset_stats,
)
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -414,9 +408,6 @@ def add_actor_information_and_train(
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observations = preprocessor.process_observation(observations)
next_observations = preprocessor.process_observation(next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
@@ -476,9 +467,6 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observations = preprocessor.process_observation(observations)
next_observations = preprocessor.process_observation(next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
+1 -1
View File
@@ -98,7 +98,7 @@ class WandBLogger:
entity=self.cfg.entity,
name=self.job_name,
notes=self.cfg.notes,
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True) if self.cfg.add_tags else None,
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
dir=self.log_dir,
config=cfg.to_dict(),
# TODO(rcadene): try set to True
@@ -17,8 +17,8 @@
import logging
from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
@@ -17,8 +17,8 @@
import logging
from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
@@ -23,7 +23,7 @@ import cv2
import numpy as np
import requests
from lerobot.types import RobotAction, RobotObservation
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
@@ -33,40 +33,21 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
logger = logging.getLogger(__name__)
# Action feature keys
ACTION_LINEAR_VEL = "linear_velocity"
ACTION_ANGULAR_VEL = "angular_velocity"
ACTION_LINEAR_VEL = "linear.vel"
ACTION_ANGULAR_VEL = "angular.vel"
# Observation feature keys — cameras
# Observation feature keys
OBS_FRONT = "front"
OBS_REAR = "rear"
# Observation feature keys — telemetry
OBS_SPEED = "speed"
OBS_BATTERY_LEVEL = "battery_level"
OBS_ORIENTATION = "orientation"
OBS_GPS_LATITUDE = "gps_latitude"
OBS_GPS_LONGITUDE = "gps_longitude"
OBS_GPS_SIGNAL = "gps_signal"
OBS_SIGNAL_LEVEL = "signal_level"
OBS_LINEAR_VEL = "linear.vel"
OBS_BATTERY_LEVEL = "battery.level"
OBS_ORIENTATION_DEG = "orientation.deg"
OBS_GPS_LATITUDE = "gps.latitude"
OBS_GPS_LONGITUDE = "gps.longitude"
OBS_GPS_SIGNAL = "gps.signal"
OBS_SIGNAL_LEVEL = "signal.level"
OBS_VIBRATION = "vibration"
OBS_LAMP = "lamp"
# Observation feature keys — IMU sensors
OBS_ACCELEROMETER_X = "accelerometer_x"
OBS_ACCELEROMETER_Y = "accelerometer_y"
OBS_ACCELEROMETER_Z = "accelerometer_z"
OBS_GYROSCOPE_X = "gyroscope_x"
OBS_GYROSCOPE_Y = "gyroscope_y"
OBS_GYROSCOPE_Z = "gyroscope_z"
OBS_MAGNETOMETER_X = "magnetometer_filtered_x"
OBS_MAGNETOMETER_Y = "magnetometer_filtered_y"
OBS_MAGNETOMETER_Z = "magnetometer_filtered_z"
# Observation feature keys — wheel RPMs
OBS_WHEEL_RPM_0 = "wheel_rpm_0"
OBS_WHEEL_RPM_1 = "wheel_rpm_1"
OBS_WHEEL_RPM_2 = "wheel_rpm_2"
OBS_WHEEL_RPM_3 = "wheel_rpm_3"
OBS_LAMP_STATE = "lamp.state"
class EarthRoverMiniPlus(Robot):
@@ -173,60 +154,33 @@ class EarthRoverMiniPlus(Robot):
dict: Observation features with types/shapes:
- front: (480, 640, 3) - Front camera RGB image
- rear: (480, 640, 3) - Rear camera RGB image
- speed: float - Current speed (raw SDK value)
- battery_level: float - Battery level (0-100)
- orientation: float - Robot orientation in degrees
- gps_latitude: float - GPS latitude coordinate
- gps_longitude: float - GPS longitude coordinate
- gps_signal: float - GPS signal strength (percentage)
- signal_level: float - Network signal level (0-5)
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
- battery.level: float - Battery level (0-1, normalized from 0-100)
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
- gps.latitude: float - GPS latitude coordinate
- gps.longitude: float - GPS longitude coordinate
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
- signal.level: float - Network signal level (0-1, normalized from 0-5)
- vibration: float - Vibration sensor reading
- lamp: float - Lamp state (0=off, 1=on)
- accelerometer_x: float - Accelerometer X axis (raw SDK value)
- accelerometer_y: float - Accelerometer Y axis (raw SDK value)
- accelerometer_z: float - Accelerometer Z axis (raw SDK value)
- gyroscope_x: float - Gyroscope X axis (raw SDK value)
- gyroscope_y: float - Gyroscope Y axis (raw SDK value)
- gyroscope_z: float - Gyroscope Z axis (raw SDK value)
- magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value)
- magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value)
- magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value)
- wheel_rpm_0: float - Wheel 0 RPM
- wheel_rpm_1: float - Wheel 1 RPM
- wheel_rpm_2: float - Wheel 2 RPM
- wheel_rpm_3: float - Wheel 3 RPM
- lamp.state: float - Lamp state (0=off, 1=on)
"""
return {
# Cameras (height, width, channels)
OBS_FRONT: (480, 640, 3),
OBS_REAR: (480, 640, 3),
# Telemetry
OBS_SPEED: float,
# Motion state
OBS_LINEAR_VEL: float,
# Robot state
OBS_BATTERY_LEVEL: float,
OBS_ORIENTATION: float,
OBS_ORIENTATION_DEG: float,
# GPS
OBS_GPS_LATITUDE: float,
OBS_GPS_LONGITUDE: float,
OBS_GPS_SIGNAL: float,
# Sensors
OBS_SIGNAL_LEVEL: float,
OBS_VIBRATION: float,
OBS_LAMP: float,
# IMU — accelerometer
OBS_ACCELEROMETER_X: float,
OBS_ACCELEROMETER_Y: float,
OBS_ACCELEROMETER_Z: float,
# IMU — gyroscope
OBS_GYROSCOPE_X: float,
OBS_GYROSCOPE_Y: float,
OBS_GYROSCOPE_Z: float,
# IMU — magnetometer
OBS_MAGNETOMETER_X: float,
OBS_MAGNETOMETER_Y: float,
OBS_MAGNETOMETER_Z: float,
# Wheel RPMs
OBS_WHEEL_RPM_0: float,
OBS_WHEEL_RPM_1: float,
OBS_WHEEL_RPM_2: float,
OBS_WHEEL_RPM_3: float,
OBS_LAMP_STATE: float,
}
@cached_property
@@ -235,8 +189,8 @@ class EarthRoverMiniPlus(Robot):
Returns:
dict: Action features with types:
- linear_velocity: float - Target linear velocity (-1 to 1)
- angular_velocity: float - Target angular velocity (-1 to 1)
- linear.vel: float - Target linear velocity
- angular.vel: float - Target angular velocity
"""
return {
ACTION_LINEAR_VEL: float,
@@ -247,29 +201,19 @@ class EarthRoverMiniPlus(Robot):
def get_observation(self) -> RobotObservation:
"""Get current robot observation from SDK.
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
Frames are decoded from base64 and converted from BGR to RGB format.
Robot telemetry is retrieved from /data endpoint.
Sensor arrays (accels, gyros, mags, rpms) each contain entries of
[values..., timestamp]; the latest reading from each array is used.
Returns:
RobotObservation: Observation containing:
- front: Front camera image (480, 640, 3) in RGB format
- rear: Rear camera image (480, 640, 3) in RGB format
- speed: float - Current speed (raw SDK value)
- battery_level: float - Battery level (0-100)
- orientation: float - Robot orientation in degrees
- gps_latitude: float - GPS latitude coordinate
- gps_longitude: float - GPS longitude coordinate
- gps_signal: float - GPS signal strength (percentage)
- signal_level: float - Network signal level (0-5)
- vibration: float - Vibration sensor reading
- lamp: float - Lamp state (0=off, 1=on)
- accelerometer_x/y/z: float - Accelerometer axes (raw SDK value)
- gyroscope_x/y/z: float - Gyroscope axes (raw SDK value)
- magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value)
- wheel_rpm_0/1/2/3: float - Wheel RPMs
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
- battery.level: Battery level (0-1, normalized from 0-100)
- orientation.deg: Robot orientation (0-1, normalized from raw value)
- gps.latitude: GPS latitude coordinate
- gps.longitude: GPS longitude coordinate
- gps.signal: GPS signal strength (0-1, normalized from percentage)
- signal.level: Network signal level (0-1, normalized from 0-5)
- vibration: Vibration sensor reading
- lamp.state: Lamp state (0=off, 1=on)
Raises:
DeviceNotConnectedError: If robot is not connected
@@ -291,41 +235,22 @@ class EarthRoverMiniPlus(Robot):
# Get robot state from SDK
robot_data = self._get_robot_data()
# Telemetry
observation[OBS_SPEED] = float(robot_data["speed"])
observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"])
observation[OBS_ORIENTATION] = float(robot_data["orientation"])
observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"])
observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"])
observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"])
observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"])
observation[OBS_VIBRATION] = float(robot_data["vibration"])
observation[OBS_LAMP] = float(robot_data["lamp"])
# Motion state
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
# Accelerometer — latest reading from accels array [x, y, z, ts]
accel = self._latest_sensor_reading(robot_data, "accels", n_values=3)
observation[OBS_ACCELEROMETER_X] = accel[0]
observation[OBS_ACCELEROMETER_Y] = accel[1]
observation[OBS_ACCELEROMETER_Z] = accel[2]
# Robot state
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
# Gyroscope — latest reading from gyros array [x, y, z, ts]
gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3)
observation[OBS_GYROSCOPE_X] = gyro[0]
observation[OBS_GYROSCOPE_Y] = gyro[1]
observation[OBS_GYROSCOPE_Z] = gyro[2]
# GPS data
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
# Magnetometer — latest reading from mags array [x, y, z, ts]
mag = self._latest_sensor_reading(robot_data, "mags", n_values=3)
observation[OBS_MAGNETOMETER_X] = mag[0]
observation[OBS_MAGNETOMETER_Y] = mag[1]
observation[OBS_MAGNETOMETER_Z] = mag[2]
# Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts]
rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4)
observation[OBS_WHEEL_RPM_0] = rpm[0]
observation[OBS_WHEEL_RPM_1] = rpm[1]
observation[OBS_WHEEL_RPM_2] = rpm[2]
observation[OBS_WHEEL_RPM_3] = rpm[3]
# Sensors
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
observation[OBS_VIBRATION] = robot_data["vibration"]
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
return observation
@@ -335,12 +260,11 @@ class EarthRoverMiniPlus(Robot):
Args:
action: Action dict with keys:
- linear_velocity: Target linear velocity (-1 to 1)
- angular_velocity: Target angular velocity (-1 to 1)
- linear.vel: Target linear velocity (-1 to 1)
- angular.vel: Target angular velocity (-1 to 1)
Returns:
RobotAction: The action that was sent (matches action_features keys)
Raises:
DeviceNotConnectedError: If robot is not connected
@@ -348,14 +272,18 @@ class EarthRoverMiniPlus(Robot):
Actions are sent to SDK via POST /control endpoint.
SDK expects commands in range [-1, 1].
"""
# Extract action values and convert to float
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
# Send command to SDK
try:
self._send_command_to_sdk(linear, angular)
except Exception as e:
logger.error(f"Error sending action: {e}")
# Return action in format matching action_features
return {
ACTION_LINEAR_VEL: linear,
ACTION_ANGULAR_VEL: angular,
@@ -466,27 +394,11 @@ class EarthRoverMiniPlus(Robot):
logger.error(f"Error decoding image: {e}")
return None
@staticmethod
def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
"""Extract the latest sensor reading from an SDK sensor array.
The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``,
``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``.
This helper returns the *n_values* leading floats from the last entry,
falling back to zeros when the key is missing or the array is empty.
"""
readings = robot_data.get(key)
if readings and len(readings) > 0:
latest = readings[-1]
return [float(v) for v in latest[:n_values]]
return [0.0] * n_values
def _get_robot_data(self) -> dict:
"""Get robot telemetry data from SDK.
Returns:
dict: Robot telemetry data including battery, speed, orientation, GPS,
and sensor arrays (accels, gyros, mags, rpms):
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
- Current data (if request succeeds)
- Cached data (if request fails but cache exists)
- Default values (if request fails and no cache exists yet)
@@ -508,23 +420,19 @@ class EarthRoverMiniPlus(Robot):
# Fallback: use cache or default values
if self._last_robot_data is not None:
return self._last_robot_data
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
"accels": [],
"gyros": [],
"mags": [],
"rpms": [],
}
else:
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
}
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
"""Send control command to SDK.
+1 -1
View File
@@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
+1 -1
View File
@@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
@@ -24,7 +24,7 @@ from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
+1 -1
View File
@@ -28,7 +28,7 @@ from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.types import RobotAction, RobotObservation
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot

Some files were not shown because too many files have changed in this diff Show More