mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 026e4c937d | |||
| efe8c09fca | |||
| 58eecad8a4 | |||
| c7fd1f47d1 | |||
| 6370949e5c | |||
| 46b97da168 | |||
| e69be57a66 | |||
| c3a6ddb668 | |||
| dad661012d | |||
| 219c08ccb8 | |||
| e64fa667c3 | |||
| d9ec3a6fa2 | |||
| d90e4bcfd3 | |||
| 9d3b62aa61 | |||
| 7c2ec31793 | |||
| a07b1d76f1 | |||
| 2ec1dafcc2 | |||
| 06385902df | |||
| 2d6259156b | |||
| 0db5f66dda | |||
| efee611403 | |||
| c15b75e3da | |||
| f311ca3dce | |||
| 19c6adef85 | |||
| 96b7f3dae0 | |||
| 885ef91892 | |||
| b0efa73520 |
+4
-4
@@ -2,7 +2,7 @@
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) and our [AI policy](https://github.com/huggingface/lerobot/blob/main/AI_POLICY.md).
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
@@ -32,7 +32,7 @@ git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
|
||||
### 2. Environment Installation
|
||||
|
||||
Please follow our [Installation Guide](./docs/source/installation.mdx) for the environment setup & installation from source.
|
||||
Please follow our [Installation Guide](https://huggingface.co/docs/lerobot/installation) for the environment setup & installation from source.
|
||||
|
||||
## Running Tests & Quality Checks
|
||||
|
||||
@@ -75,8 +75,8 @@ pytest -sv tests/test_specific_feature.py
|
||||
|
||||
Use the templates for required fields and examples.
|
||||
|
||||
- **Issues:** Follow the [ticket template](./.github/ISSUE_TEMPLATE/bug-report.yml).
|
||||
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](./.github/PULL_REQUEST_TEMPLATE.md).
|
||||
- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
|
||||
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).
|
||||
|
||||
One member of the LeRobot team will then review your contribution.
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ If you are referencing our research or the academic paper, please also cite our
|
||||
|
||||
## Contribute
|
||||
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](./CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
We welcome contributions from everyone in the community! To get started, please read our [CONTRIBUTING.md](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) guide. Whether you're adding a new feature, improving documentation, or fixing a bug, your help and feedback are invaluable. We're incredibly excited about the future of open-source robotics and can't wait to work with you on what's next—thank you for your support!
|
||||
|
||||
<p align="center">
|
||||
<img alt="SO101 Video" src="./media/readme/so100_video.webp" width="640px">
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
# docker build -f docker/Dockerfile.user -t lerobot-user .
|
||||
# docker run -it --rm lerobot-user
|
||||
|
||||
# With USB physical access : docker run -it --device=/dev/ -v /dev/:/dev/ --rm lerobot-user
|
||||
|
||||
# Configure the base image
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
|
||||
@@ -204,22 +204,26 @@ Replace `your_username/dataset_name` with your Hugging Face username and a name
|
||||
|
||||
Your dataset includes:
|
||||
|
||||
**Your Actions (2 things)**:
|
||||
**Your Actions (2 features)**:
|
||||
|
||||
- How much you moved forward/backward
|
||||
- How much you turned left/right
|
||||
- `linear_velocity`: How much you moved forward/backward
|
||||
- `angular_velocity`: How much you turned left/right
|
||||
|
||||
**Robot Observations (12 things)**:
|
||||
**Robot Observations (24 features)**:
|
||||
|
||||
- Front camera video
|
||||
- Rear camera video
|
||||
- Current speed
|
||||
- Battery level
|
||||
- Which way the robot is facing
|
||||
- GPS location (latitude, longitude, signal strength)
|
||||
- Orientation
|
||||
- GPS (latitude, longitude, signal strength)
|
||||
- Network signal strength
|
||||
- Vibration level
|
||||
- Lamp status (on/off)
|
||||
- Lamp state (on/off)
|
||||
- Accelerometer (x, y, z)
|
||||
- Gyroscope (x, y, z)
|
||||
- Magnetometer (x, y, z)
|
||||
- Wheel RPMs (4 wheels)
|
||||
|
||||
### Where Your Data Goes
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
|
||||
+84
-43
@@ -12,36 +12,59 @@ The Unitree G1 humanoid is now supported in LeRobot! You can teleoperate, train
|
||||
|
||||
## Part 1: Getting Started
|
||||
|
||||
### Install LeRobot on Your Machine
|
||||
### Install the Unitree SDK
|
||||
|
||||
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation). Tested with `unitree_sdk2py==1.0.1` and `cyclonedds==0.10.2`:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
cd unitree_sdk2_python
|
||||
pip install -e .
|
||||
cd ..
|
||||
```
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
<Tip>
|
||||
For now, pinocchio must be installed from conda-forge (not pip) to include the
|
||||
CasADi bindings needed for arm IK.
|
||||
</Tip>
|
||||
|
||||
### Test the Installation (Simulation)
|
||||
|
||||
The simulation environment has its own dependencies. Check the Simulation environment dependencies: [Unitree G1 Mujoco EnvHub](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main).
|
||||
|
||||
```bash
|
||||
pip install mujoco loguru msgpack msgpack-numpy
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.id=wbc_unitree \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=true
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30, "warmup_s": 5}}' \
|
||||
--display_data=true \
|
||||
--robot.controller=GrootLocomotionController
|
||||
```
|
||||
|
||||
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1.
|
||||
This will launch a [MuJoCo sim instance](https://huggingface.co/lerobot/unitree-g1-mujoco/tree/main) for the G1. You can connect a gamepad to your machine before launching in order to control the robot's locomotion in sim. We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) via `--robot.controller`.
|
||||
|
||||
- Press `9` to release the robot
|
||||
- Press `7` / `8` to increase / decrease waist height
|
||||
|
||||
### Connect to the Robot
|
||||
### Connect to the Physical Robot
|
||||
|
||||
The G1's Ethernet IP is fixed at `192.168.123.164`. Your machine must have a static IP on the same subnet: `192.168.123.x` where `x ≠ 164`.
|
||||
|
||||
@@ -59,37 +82,11 @@ ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
### Install LeRobot on the G1
|
||||
### Share Internet via Ethernet
|
||||
|
||||
From the robot:
|
||||
The G1 needs internet access to clone repos and install packages. Share your laptop's connection over Ethernet:
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
> **Note:** The Unitree SDK requires CycloneDDS v0.10.2. See the [Unitree SDK docs](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Wi-Fi connectivity is blocked by default on the G1. To activate:
|
||||
|
||||
```bash
|
||||
sudo rfkill unblock all
|
||||
sudo ip link set wlan0 up
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
**On your laptop** (share internet via Ethernet):
|
||||
**On your laptop:**
|
||||
|
||||
```bash
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
@@ -100,7 +97,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the G1** (set default route through your laptop):
|
||||
**On the G1:**
|
||||
|
||||
```bash
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
@@ -111,6 +108,45 @@ echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
### Install the Unitree SDK on the G1
|
||||
|
||||
Follow the [unitree_sdk2_python installation guide](https://github.com/unitreerobotics/unitree_sdk2_python#installation):
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.12
|
||||
conda activate lerobot
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python
|
||||
python -m pip install -e .
|
||||
cd ..
|
||||
```
|
||||
|
||||
### Install LeRobot on the G1
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
conda install -c conda-forge "pinocchio>=3.0.0,<4.0.0"
|
||||
python -m pip install -e '.[unitree_g1]'
|
||||
```
|
||||
|
||||
<Tip>
|
||||
For now, pinocchio must be installed from conda-forge (not pip) to include the
|
||||
CasADi bindings needed for arm IK.
|
||||
</Tip>
|
||||
|
||||
### (Optional) Enable WiFi on the Robot
|
||||
|
||||
For wireless SSH access, you can enable WiFi on the G1 (it's blocked by default):
|
||||
|
||||
```bash
|
||||
sudo rfkill unblock all
|
||||
sudo ip link set wlan0 up
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
**Connect to a WiFi network:**
|
||||
|
||||
```bash
|
||||
@@ -125,7 +161,7 @@ sudo nmcli connection up "YourNetwork"
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
You can now SSH over WiFi:
|
||||
You can then SSH over WiFi instead of Ethernet:
|
||||
|
||||
```bash
|
||||
ssh unitree@<ROBOT_WIFI_IP>
|
||||
@@ -134,18 +170,23 @@ ssh unitree@<ROBOT_WIFI_IP>
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Teleoperation & Locomotion
|
||||
## Part 2: Teleoperation & Locomotion
|
||||
|
||||
### Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
On the robot (from `~/lerobot`):
|
||||
|
||||
```bash
|
||||
cd ~/lerobot
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py --camera
|
||||
```
|
||||
|
||||
### Run the Locomotion Policy
|
||||
|
||||
You can run the teleoperation client from your laptop over Ethernet, over WiFi (experimental), or directly on the robot itself. Mind potential latency introduced by your network.
|
||||
|
||||
**From your laptop:**
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
@@ -158,13 +199,13 @@ lerobot-teleoperate \
|
||||
--robot.controller=HolosomaLocomotionController
|
||||
```
|
||||
|
||||
We support both [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) and [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl).
|
||||
We support both [GrootLocomotionController](https://github.com/NVlabs/GR00T-WholeBodyControl) and [HolosomaLocomotionController](https://github.com/amazon-far/holosoma) via `--robot.controller`.
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Loco-Manipulation with the Homunculus Exoskeleton
|
||||
## Part 3: Loco-Manipulation with the Homunculus Exoskeleton
|
||||
|
||||
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Assembly instructions [here](https://github.com/nepyope/hmc_exo).
|
||||
We provide a loco-manipulation solution via the Homunculus Exoskeleton — an open-source 7 DoF exoskeleton for whole-body control. Check it out [here](https://github.com/nepyope/hmc_exo).
|
||||
|
||||
### Calibrate
|
||||
|
||||
@@ -205,7 +246,7 @@ Example dataset: [nepyope/unitree_box_move_blue_full](https://huggingface.co/dat
|
||||
|
||||
---
|
||||
|
||||
## Part 5: Training & Inference
|
||||
## Part 4: Training & Inference
|
||||
|
||||
### Train
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Action consistency analysis for imitation learning datasets.
|
||||
|
||||
Two parallel analyses per dataset:
|
||||
1. State-based: KNN in joint-state space → action chunk variance
|
||||
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
||||
|
||||
Comparing them reveals whether visual similarity and proprioceptive similarity
|
||||
agree on where the data is inconsistent — and images are what the policy
|
||||
primarily sees.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from PIL import Image
|
||||
from scipy.spatial import cKDTree
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
MAX_FRAMES = 100_000
|
||||
K_NEIGHBORS = 50
|
||||
ACTION_CHUNK_SIZE = 30
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
||||
ENCODE_BATCH_SIZE = 512
|
||||
SEED = 42
|
||||
DPI = 150
|
||||
|
||||
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
|
||||
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
|
||||
)
|
||||
|
||||
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data helpers ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str, camera_key: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=[
|
||||
"meta/**",
|
||||
"data/**",
|
||||
f"videos/{camera_key}/**",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_action_chunks(
|
||||
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
For each frame, concatenate the next chunk_size actions from the same episode.
|
||||
Returns (action_chunks, valid_mask).
|
||||
"""
|
||||
n = len(actions)
|
||||
act_dim = actions.shape[1]
|
||||
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
|
||||
valid = np.zeros(n, dtype=bool)
|
||||
|
||||
for i in range(n):
|
||||
end = i + chunk_size
|
||||
if end > n:
|
||||
continue
|
||||
if episode_ids[i] != episode_ids[end - 1]:
|
||||
continue
|
||||
chunks[i] = actions[i:end].ravel()
|
||||
valid[i] = True
|
||||
|
||||
return chunks, valid
|
||||
|
||||
|
||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
||||
"""
|
||||
Load observation.state and action, build action chunks, subsample, normalize.
|
||||
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
n_total = len(df)
|
||||
print(f" Total frames: {n_total:,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
action_col = next((c for c in df.columns if c == "action"), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
if action_col is None:
|
||||
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
state_all = np.stack(df[state_col].values).astype(np.float64)
|
||||
action_all = np.stack(df[action_col].values).astype(np.float64)
|
||||
episode_all = df[ep_col].values.astype(np.int64)
|
||||
|
||||
n_dim = state_all.shape[1]
|
||||
act_dim = action_all.shape[1]
|
||||
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
|
||||
print(f" Action chunk dim: {chunk_size * act_dim}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
print(" Building action chunks …")
|
||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
||||
valid_idx = np.where(valid)[0]
|
||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
||||
|
||||
if len(valid_idx) > max_frames:
|
||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
||||
else:
|
||||
chosen = valid_idx
|
||||
print(f" Using {len(chosen):,} frames")
|
||||
|
||||
state_raw = state_all[chosen]
|
||||
action_raw = action_chunks[chosen]
|
||||
episode_ids = episode_all[chosen]
|
||||
|
||||
state_mean = state_raw.mean(axis=0)
|
||||
state_std = state_raw.std(axis=0)
|
||||
state_std[state_std < 1e-8] = 1.0
|
||||
state_norm = (state_raw - state_mean) / state_std
|
||||
|
||||
action_mean = action_raw.mean(axis=0)
|
||||
action_std = action_raw.std(axis=0)
|
||||
action_std[action_std < 1e-8] = 1.0
|
||||
action_norm = (action_raw - action_mean) / action_std
|
||||
|
||||
return {
|
||||
"state_raw": state_raw,
|
||||
"state_norm": state_norm,
|
||||
"action_raw": action_raw,
|
||||
"action_norm": action_norm,
|
||||
"episode_ids": episode_ids,
|
||||
"episode_all": episode_all,
|
||||
"left_joint_idx": left_idx,
|
||||
"right_joint_idx": right_idx,
|
||||
"n_total": n_total,
|
||||
"chosen_idx": chosen,
|
||||
"df": df,
|
||||
}
|
||||
|
||||
|
||||
# ── Video → frame extraction ──────────────────────────────
|
||||
|
||||
|
||||
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
||||
"""
|
||||
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
chunk_col = f"videos/{camera_key}/chunk_index"
|
||||
file_col = f"videos/{camera_key}/file_index"
|
||||
ts_from = f"videos/{camera_key}/from_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{camera_key}/chunk_index"
|
||||
file_col = f"{camera_key}/file_index"
|
||||
ts_from = f"{camera_key}/from_timestamp"
|
||||
|
||||
lookup: dict[int, dict] = {}
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
||||
lookup[int(row["episode_index"])] = {
|
||||
"video_path": local / video_rel,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"fps": fps,
|
||||
}
|
||||
return lookup
|
||||
|
||||
|
||||
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
|
||||
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
|
||||
container = av.open(video_path)
|
||||
stream = container.streams.video[0]
|
||||
stream.thread_type = "AUTO"
|
||||
decoded = []
|
||||
for frame in container.decode(stream):
|
||||
decoded.append(frame.to_ndarray(format="rgb24"))
|
||||
container.close()
|
||||
return decoded
|
||||
|
||||
|
||||
def extract_frames(
|
||||
chosen_idx: np.ndarray,
|
||||
episode_all: np.ndarray,
|
||||
video_lookup: dict,
|
||||
) -> list[np.ndarray | None]:
|
||||
"""
|
||||
Extract RGB frames for each chosen global index using PyAV.
|
||||
Returns list of (H, W, 3) RGB arrays (or None on failure).
|
||||
"""
|
||||
unique_eps = np.unique(episode_all)
|
||||
ep_start: dict[int, int] = {}
|
||||
for ep in unique_eps:
|
||||
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
||||
|
||||
# Build jobs: (output_index, video_path, local_frame_number)
|
||||
jobs: list[tuple[int, str, int]] = []
|
||||
for out_i, global_i in enumerate(chosen_idx):
|
||||
ep = int(episode_all[global_i])
|
||||
info = video_lookup.get(ep)
|
||||
if info is None:
|
||||
continue
|
||||
local_frame = global_i - ep_start[ep]
|
||||
jobs.append((out_i, str(info["video_path"]), local_frame))
|
||||
|
||||
# Group by video file, decode each video once
|
||||
from collections import defaultdict
|
||||
|
||||
video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
|
||||
for out_i, vpath, local_frame in jobs:
|
||||
video_jobs[vpath].append((out_i, local_frame))
|
||||
|
||||
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
|
||||
extracted = 0
|
||||
n_videos = len(video_jobs)
|
||||
for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
|
||||
if not Path(vpath).exists():
|
||||
continue
|
||||
try:
|
||||
decoded = _decode_video_frames(vpath)
|
||||
except Exception as exc:
|
||||
print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
|
||||
continue
|
||||
for out_i, local_frame in frame_requests:
|
||||
if 0 <= local_frame < len(decoded):
|
||||
frames[out_i] = decoded[local_frame]
|
||||
extracted += 1
|
||||
if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
|
||||
print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
|
||||
del decoded
|
||||
|
||||
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
||||
return frames
|
||||
|
||||
|
||||
# ── SigLIP encoding ─────────────────────────────────────
|
||||
|
||||
|
||||
def encode_frames_siglip(
|
||||
frames: list[np.ndarray | None],
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Encode RGB frames through SigLIP vision encoder.
|
||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||
"""
|
||||
print(f" Loading SigLIP model: {model_name} …")
|
||||
processor = AutoImageProcessor.from_pretrained(model_name)
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
embed_dim = model.config.vision_config.hidden_size
|
||||
|
||||
n = len(frames)
|
||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||
|
||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||
|
||||
for batch_start in range(0, len(valid_indices), batch_size):
|
||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
||||
|
||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
image_features = model.get_image_features(**inputs)
|
||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||
|
||||
done = min(batch_start + batch_size, len(valid_indices))
|
||||
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
||||
print(f" {done:,} / {len(valid_indices):,} encoded")
|
||||
|
||||
del model, processor
|
||||
torch.cuda.empty_cache()
|
||||
return embeddings
|
||||
|
||||
|
||||
# ── KNN consistency ─────────────────────────────────────
|
||||
|
||||
|
||||
def compute_consistency(
|
||||
features: np.ndarray,
|
||||
action_norm: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
k: int,
|
||||
label: str = "",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
For each frame, find K nearest neighbors in feature space from other episodes.
|
||||
Return per-frame action variance (mean across action dims).
|
||||
"""
|
||||
n = len(features)
|
||||
print(f" Building KD-tree on {n:,} vectors ({label}) …")
|
||||
tree = cKDTree(features)
|
||||
|
||||
k_query = min(k * 3, n - 1)
|
||||
print(f" Querying {k_query} neighbors per frame …")
|
||||
_dists, indices = tree.query(features, k=k_query + 1)
|
||||
indices = indices[:, 1:]
|
||||
|
||||
print(f" Computing cross-episode action variance ({label}) …")
|
||||
variance = np.zeros(n)
|
||||
for i in range(n):
|
||||
ep_i = episode_ids[i]
|
||||
neighbors = indices[i]
|
||||
cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
|
||||
if len(cross_ep) < 2:
|
||||
variance[i] = 0.0
|
||||
continue
|
||||
neighbor_actions = action_norm[cross_ep]
|
||||
variance[i] = np.mean(np.var(neighbor_actions, axis=0))
|
||||
|
||||
return variance
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor("#0d1117")
|
||||
ax.tick_params(colors="#555", labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
|
||||
|
||||
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
||||
_style_ax(ax)
|
||||
median_var = np.median(variance)
|
||||
mean_var = np.mean(variance)
|
||||
nonzero = variance[variance > 0]
|
||||
if len(nonzero) > 0:
|
||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
||||
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
||||
ax.set_xscale("log")
|
||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_episode_curves(
|
||||
ax: plt.Axes,
|
||||
var_state: np.ndarray,
|
||||
var_image: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
title: str,
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
unique_eps = np.unique(episode_ids)
|
||||
|
||||
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
||||
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
||||
|
||||
sorted_s = np.sort(ep_means_s)[::-1]
|
||||
sorted_i = np.sort(ep_means_i)[::-1]
|
||||
ep_x = np.arange(len(unique_eps))
|
||||
|
||||
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
||||
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
||||
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
||||
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
||||
|
||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_heatmap(
|
||||
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
order = np.argsort(variance)
|
||||
pts = tcp_xz[order]
|
||||
var_sorted = variance[order]
|
||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
||||
sc = ax.scatter(
|
||||
pts[:, 0],
|
||||
pts[:, 1],
|
||||
c=var_sorted,
|
||||
cmap=CONSISTENCY_CMAP,
|
||||
s=0.5,
|
||||
alpha=0.6,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
rasterized=True,
|
||||
)
|
||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.set_aspect("equal")
|
||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
4-row x N-column figure:
|
||||
Row 0: State-based variance histogram
|
||||
Row 1: Image-based variance histogram
|
||||
Row 2: Per-episode curves (both overlaid)
|
||||
Row 3: Spatial heatmap (image-based variance)
|
||||
"""
|
||||
n_ds = len(results)
|
||||
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[:, np.newaxis]
|
||||
|
||||
headline_parts = []
|
||||
for col, r in enumerate(results):
|
||||
label = r["label"]
|
||||
var_s = r["var_state"]
|
||||
var_i = r["var_image"]
|
||||
tcp_xz = r["tcp_xz"]
|
||||
episode_ids = r["episode_ids"]
|
||||
|
||||
med_s = np.median(var_s)
|
||||
med_i = np.median(var_i)
|
||||
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
||||
|
||||
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
||||
_plot_histogram(
|
||||
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
||||
)
|
||||
_plot_episode_curves(
|
||||
axes[2, col],
|
||||
var_s,
|
||||
var_i,
|
||||
episode_ids,
|
||||
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
||||
)
|
||||
_plot_heatmap(
|
||||
axes[3, col],
|
||||
fig,
|
||||
tcp_xz,
|
||||
var_i,
|
||||
f"{label}\nImage-based variance by TCP position (XZ)",
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
||||
+ " | ".join(headline_parts),
|
||||
color="white",
|
||||
fontsize=15,
|
||||
y=0.99,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Device: {device}")
|
||||
rng = np.random.default_rng(SEED)
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id, CAMERA_KEY)
|
||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
||||
|
||||
# --- State-based KNN ---
|
||||
var_state = compute_consistency(
|
||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
||||
)
|
||||
print(
|
||||
f" State variance: median={np.median(var_state):.4f} "
|
||||
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
||||
)
|
||||
|
||||
# --- Image-based KNN ---
|
||||
print("\n Preparing image embeddings …")
|
||||
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
||||
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
||||
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
||||
del frames # free memory
|
||||
|
||||
var_image = compute_consistency(
|
||||
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
||||
)
|
||||
print(
|
||||
f" Image variance: median={np.median(var_image):.4f} "
|
||||
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
||||
)
|
||||
|
||||
# FK for spatial heatmap
|
||||
print(" Computing FK for spatial heatmap …")
|
||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
||||
left_rad = _detect_and_convert(left_raw)
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
|
||||
tcp_xz = left_tcp[:, [0, 2]]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"var_state": var_state,
|
||||
"var_image": var_image,
|
||||
"episode_ids": data["episode_ids"],
|
||||
"tcp_xz": tcp_xz,
|
||||
"n_total": data["n_total"],
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
||||
render(results, out)
|
||||
|
||||
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
||||
worst_summary = {}
|
||||
for r in results:
|
||||
unique_eps = np.unique(r["episode_ids"])
|
||||
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
||||
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
||||
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
||||
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
||||
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
||||
print(f"✓ Saved worst episodes: {worst_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Create a JPG grid of random frames sampled from a LeRobot video dataset.
|
||||
Downloads metadata + video chunks from HuggingFace, picks random frames,
|
||||
decodes them, and tiles into a single image.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
REPO_ID = "lerobot-data-collection/level2_final_quality3"
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
GRID_COLS = 15
|
||||
GRID_ROWS = 10
|
||||
THUMB_WIDTH = 160
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
SEED = 1
|
||||
|
||||
|
||||
def download_metadata(repo_id: str) -> Path:
|
||||
"""Download only metadata (no videos yet)."""
|
||||
print(f"[1/3] Downloading metadata for {repo_id} …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
|
||||
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
cam = CAMERA_KEY
|
||||
else:
|
||||
cam = video_keys[0]
|
||||
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
chunk_col = f"videos/{cam}/chunk_index"
|
||||
file_col = f"videos/{cam}/file_index"
|
||||
ts_from = f"videos/{cam}/from_timestamp"
|
||||
ts_to = f"videos/{cam}/to_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{cam}/chunk_index"
|
||||
file_col = f"{cam}/file_index"
|
||||
ts_from = f"{cam}/from_timestamp"
|
||||
ts_to = f"{cam}/to_timestamp"
|
||||
|
||||
episodes = []
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
episodes.append(
|
||||
{
|
||||
"episode_index": int(row["episode_index"]),
|
||||
"chunk_index": ci,
|
||||
"file_index": fi,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"to_ts": float(row[ts_to]),
|
||||
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
|
||||
}
|
||||
)
|
||||
return cam, episodes, fps
|
||||
|
||||
|
||||
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
|
||||
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
|
||||
picks = []
|
||||
for _ in range(n):
|
||||
ep = rng.choice(episodes)
|
||||
duration = ep["to_ts"] - ep["from_ts"]
|
||||
if duration <= 0:
|
||||
continue
|
||||
t = ep["from_ts"] + rng.random() * duration
|
||||
picks.append({**ep, "seek_ts": t})
|
||||
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
|
||||
return picks
|
||||
|
||||
|
||||
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
|
||||
"""Download only the video files we need."""
|
||||
needed = sorted({p["video_rel"] for p in picks})
|
||||
print(f"[2/3] Downloading {len(needed)} video file(s) …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=needed,
|
||||
)
|
||||
|
||||
|
||||
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
|
||||
"""Decode a single frame at the given timestamp."""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
return frame if ret else None
|
||||
|
||||
|
||||
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
|
||||
"""Resize frames to uniform thumbnails and tile into a grid."""
|
||||
if not frames:
|
||||
raise RuntimeError("No frames decoded")
|
||||
|
||||
h0, w0 = frames[0].shape[:2]
|
||||
thumb_h = int(thumb_w * h0 / w0)
|
||||
|
||||
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
|
||||
|
||||
rows = []
|
||||
for i in range(0, len(thumbs), cols):
|
||||
row_thumbs = thumbs[i : i + cols]
|
||||
while len(row_thumbs) < cols:
|
||||
row_thumbs.append(np.zeros_like(row_thumbs[0]))
|
||||
rows.append(np.hstack(row_thumbs))
|
||||
return np.vstack(rows)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rng = random.Random(SEED)
|
||||
n_frames = GRID_COLS * GRID_ROWS
|
||||
|
||||
local = download_metadata(REPO_ID)
|
||||
cam, episodes, fps = load_video_info(local)
|
||||
picks = pick_random_frames(episodes, fps, n_frames, rng)
|
||||
download_video_files(REPO_ID, local, picks)
|
||||
|
||||
print(f"[3/3] Decoding {n_frames} frames …")
|
||||
frames: list[np.ndarray] = []
|
||||
for p in picks:
|
||||
vp = local / p["video_rel"]
|
||||
if not vp.exists():
|
||||
print(f" SKIP: {p['video_rel']} not found")
|
||||
continue
|
||||
frame = extract_frame(vp, p["seek_ts"])
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
print(f" Decoded {len(frames)}/{n_frames} frames")
|
||||
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
|
||||
|
||||
safe_name = REPO_ID.replace("/", "_")
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
|
||||
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
||||
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
Create MP4 videos with sarm_progress overlay for specified episodes.
|
||||
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
||||
and draws the progress line directly on each frame (no panel, no axes).
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
|
||||
]
|
||||
CAMERA_KEY = (
|
||||
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
|
||||
)
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Progress line spans the full video height
|
||||
GRAPH_Y_TOP_FRAC = 0.01
|
||||
GRAPH_Y_BOT_FRAC = 0.99
|
||||
LINE_THICKNESS = 3
|
||||
SHADOW_THICKNESS = 6 # white edge thickness
|
||||
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
|
||||
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
|
||||
SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the files needed for this episode."""
|
||||
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
|
||||
# We'll download meta + sarm first, then figure out chunks.
|
||||
print(f"\n[1/5] Downloading metadata for {repo_id} …")
|
||||
local = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
return local
|
||||
|
||||
|
||||
def load_episode_meta(local: Path, episode: int) -> dict:
|
||||
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
# Find video keys (keys whose dtype=="video")
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
first_cam = CAMERA_KEY
|
||||
else:
|
||||
first_cam = video_keys[0]
|
||||
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
|
||||
|
||||
# Load all episode-meta parquet files and find our episode
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
df = pd.read_parquet(pq)
|
||||
ep_rows.append(df)
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
row = ep_df[ep_df["episode_index"] == episode]
|
||||
if row.empty:
|
||||
raise RuntimeError(f"Episode {episode} not found in episode metadata")
|
||||
row = row.iloc[0]
|
||||
|
||||
# Extract video chunk/file index for first camera
|
||||
# Try both dot and slash variants of the key
|
||||
chunk_col = f"videos/{first_cam}/chunk_index"
|
||||
file_col = f"videos/{first_cam}/file_index"
|
||||
ts_col = f"videos/{first_cam}/from_timestamp"
|
||||
to_col = f"videos/{first_cam}/to_timestamp"
|
||||
|
||||
# Some datasets use different column naming
|
||||
if chunk_col not in row.index:
|
||||
# Try without the 'videos/' prefix
|
||||
chunk_col = f"{first_cam}/chunk_index"
|
||||
file_col = f"{first_cam}/file_index"
|
||||
ts_col = f"{first_cam}/from_timestamp"
|
||||
to_col = f"{first_cam}/to_timestamp"
|
||||
if chunk_col not in row.index:
|
||||
raise RuntimeError(
|
||||
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
|
||||
)
|
||||
|
||||
chunk_idx = int(row[chunk_col])
|
||||
file_idx = int(row[file_col])
|
||||
from_ts = float(row[ts_col])
|
||||
to_ts = float(row[to_col])
|
||||
|
||||
video_template = info.get(
|
||||
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
|
||||
)
|
||||
video_rel = video_template.format(
|
||||
video_key=first_cam,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# Load task name for this episode
|
||||
# tasks.parquet uses the task string as the row index; task_index column holds the int id
|
||||
task_name = ""
|
||||
try:
|
||||
# Prefer the 'tasks' list directly on the episode row
|
||||
if "tasks" in row.index and row["tasks"] is not None:
|
||||
tasks_val = row["tasks"]
|
||||
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
|
||||
task_name = str(tasks_val[0])
|
||||
else:
|
||||
task_name = str(tasks_val).strip("[]'")
|
||||
else:
|
||||
tasks_pq = local / "meta" / "tasks.parquet"
|
||||
if tasks_pq.exists():
|
||||
tasks_df = pd.read_parquet(tasks_pq)
|
||||
# Row index is the task string; task_index column is the int
|
||||
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
|
||||
match = tasks_df[tasks_df["task_index"] == task_idx]
|
||||
if not match.empty:
|
||||
task_name = str(match.index[0])
|
||||
print(f" Task name: '{task_name}'")
|
||||
except Exception as e:
|
||||
print(f" WARNING: could not load task name: {e}")
|
||||
|
||||
return {
|
||||
"fps": fps,
|
||||
"first_cam": first_cam,
|
||||
"video_rel": video_rel,
|
||||
"chunk_index": chunk_idx,
|
||||
"file_index": file_idx,
|
||||
"from_ts": from_ts,
|
||||
"to_ts": to_ts,
|
||||
"task_name": task_name,
|
||||
}
|
||||
|
||||
|
||||
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
|
||||
"""Download the specific video file if not already present."""
|
||||
video_path = local / video_rel
|
||||
if video_path.exists():
|
||||
print(f" Video already cached: {video_path}")
|
||||
return video_path
|
||||
print(f"[2/5] Downloading video file {video_rel} …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=[video_rel],
|
||||
)
|
||||
if not video_path.exists():
|
||||
raise RuntimeError(f"Video not found after download: {video_path}")
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
||||
pq_path = local / "sarm_progress.parquet"
|
||||
if not pq_path.exists():
|
||||
print(" WARNING: sarm_progress.parquet not found, trying data parquet …")
|
||||
return None
|
||||
df = pd.read_parquet(pq_path)
|
||||
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
||||
ep_df = df[df["episode_index"] == episode].copy()
|
||||
if ep_df.empty:
|
||||
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
||||
return None
|
||||
ep_df = ep_df.sort_values("frame_index")
|
||||
|
||||
# Prefer dense, fall back to sparse
|
||||
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
||||
prog_col = "progress_dense"
|
||||
elif "progress_sparse" in ep_df.columns:
|
||||
prog_col = "progress_sparse"
|
||||
else:
|
||||
# Last resort: any column with 'progress' in the name
|
||||
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
||||
if not prog_cols:
|
||||
return None
|
||||
prog_col = prog_cols[0]
|
||||
|
||||
print(f" Using progress column: '{prog_col}'")
|
||||
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
||||
|
||||
|
||||
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
|
||||
"""Use ffmpeg to cut the episode segment from the combined video file."""
|
||||
duration = to_ts - from_ts
|
||||
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-ss",
|
||||
str(from_ts),
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"fast",
|
||||
"-crf",
|
||||
"18",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
|
||||
return out_path
|
||||
|
||||
|
||||
def precompute_pixels(
|
||||
progress_data: np.ndarray,
|
||||
n_frames: int,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Map each progress sample to pixel coordinates.
|
||||
Returns array of shape (N, 2) with (x, y) in pixel space.
|
||||
x spans full video width; y maps progress [0,1] to graph band.
|
||||
"""
|
||||
frame_indices = progress_data[:, 0].astype(float)
|
||||
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
|
||||
|
||||
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
graph_h = y_bot - y_top
|
||||
|
||||
xs = (frame_indices / (n_frames - 1) * (frame_w - 1)).astype(int)
|
||||
# progress=1 → y_top, progress=0 → y_bot
|
||||
ys = (y_bot - progress_vals * graph_h).astype(int)
|
||||
|
||||
return np.stack([xs, ys], axis=1) # (N, 2)
|
||||
|
||||
|
||||
def progress_color(t: float) -> tuple[int, int, int]:
|
||||
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
|
||||
r = int(255 * (1.0 - t))
|
||||
g = int(255 * t)
|
||||
return (0, g, r) # BGR
|
||||
|
||||
|
||||
def prerender_fill(
|
||||
pixels: np.ndarray,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
poly = np.concatenate(
|
||||
[
|
||||
pixels,
|
||||
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
|
||||
],
|
||||
axis=0,
|
||||
).astype(np.int32)
|
||||
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
|
||||
return fill_img
|
||||
|
||||
|
||||
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
|
||||
"""Blend overlay onto base in-place, but only for x < x_max."""
|
||||
if x_max <= 0:
|
||||
return
|
||||
roi_b = base[:, :x_max]
|
||||
roi_o = overlay_bgra[:, :x_max]
|
||||
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
|
||||
roi_b[:] = np.clip(
|
||||
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
|
||||
0,
|
||||
255,
|
||||
).astype(np.uint8)
|
||||
|
||||
|
||||
def draw_text_outlined(
|
||||
frame: np.ndarray,
|
||||
text: str,
|
||||
pos: tuple[int, int],
|
||||
font_scale: float,
|
||||
thickness: int = 1,
|
||||
) -> None:
|
||||
"""Draw text with a dark outline for readability on any background."""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
|
||||
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
|
||||
def composite_video(
|
||||
clip_path: Path,
|
||||
progress_data: np.ndarray,
|
||||
out_path: Path,
|
||||
fps: float,
|
||||
frame_h: int,
|
||||
frame_w: int,
|
||||
task_name: str = "",
|
||||
) -> Path:
|
||||
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
|
||||
n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
|
||||
|
||||
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
|
||||
# Pre-render fill polygon (line is drawn per-frame with live color)
|
||||
fill_img = prerender_fill(pixels, frame_w, frame_h)
|
||||
|
||||
# 1.0 reference line overlay (full width, drawn once)
|
||||
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
|
||||
|
||||
frame_indices = progress_data[:, 0].astype(int)
|
||||
progress_vals = progress_data[:, 1].astype(float)
|
||||
|
||||
print(f"[4/4] Compositing {n_total} frames …")
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
|
||||
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
|
||||
|
||||
fi = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
|
||||
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
|
||||
|
||||
# 1. reference line (full width, always)
|
||||
alpha_composite(frame, ref_img, frame_w)
|
||||
|
||||
# 2. grey fill under curve up to current x
|
||||
alpha_composite(frame, fill_img, x_cur)
|
||||
|
||||
# 3. progress line — single color that transitions red→green over time
|
||||
if n_drawn >= 2:
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
line_col = progress_color(t_cur)
|
||||
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
|
||||
cv2.polylines(
|
||||
frame,
|
||||
[pts],
|
||||
isClosed=False,
|
||||
color=(255, 255, 255),
|
||||
thickness=SHADOW_THICKNESS,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
cv2.polylines(
|
||||
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 4. score — bottom right
|
||||
if n_drawn > 0:
|
||||
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
|
||||
score_text = f"{score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
|
||||
sx = frame_w - tw - 12
|
||||
sy = frame_h - 12
|
||||
# coloured score matching current gradient position
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
score_col = progress_color(t_cur)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
(0, 0, 0),
|
||||
4,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
score_col,
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
# 5. task name — top centre
|
||||
if task_name:
|
||||
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
|
||||
tx = max((frame_w - tw) // 2, 4)
|
||||
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
|
||||
|
||||
writer.write(frame)
|
||||
fi += 1
|
||||
if fi % 100 == 0:
|
||||
print(f" Frame {fi}/{n_total} …", end="\r")
|
||||
|
||||
cap.release()
|
||||
writer.release()
|
||||
print()
|
||||
|
||||
# Convert to GIF: full resolution, 12fps, 128-color diff palette (<40MB)
|
||||
gif_path = out_path.with_suffix(".gif")
|
||||
palette = out_path.parent / "_palette.png"
|
||||
r1 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-vf",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
|
||||
"-update",
|
||||
"1",
|
||||
str(palette),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r1.returncode != 0:
|
||||
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
|
||||
r2 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-i",
|
||||
str(palette),
|
||||
"-filter_complex",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
|
||||
str(gif_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r2.returncode != 0:
|
||||
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
palette.unlink(missing_ok=True)
|
||||
return gif_path
|
||||
|
||||
|
||||
def process_dataset(repo_id: str, episode: int):
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Processing: {repo_id} | episode {episode}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Download metadata
|
||||
local = download_episode(repo_id, episode)
|
||||
print(f" Local cache: {local}")
|
||||
|
||||
# 2. Read episode metadata
|
||||
ep_meta = load_episode_meta(local, episode)
|
||||
print(f" Episode meta: {ep_meta}")
|
||||
|
||||
# 3. Download video file
|
||||
video_path = download_video(repo_id, local, ep_meta["video_rel"])
|
||||
|
||||
# 4. Extract clip
|
||||
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
|
||||
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
|
||||
|
||||
# 5. Load progress data
|
||||
progress_data = load_progress(local, episode)
|
||||
if progress_data is None:
|
||||
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
|
||||
return
|
||||
|
||||
n_progress = len(progress_data)
|
||||
print(f" Progress frames: {n_progress}")
|
||||
|
||||
# 6. Get clip dimensions
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
|
||||
cap.release()
|
||||
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
|
||||
|
||||
# 7. Composite (draw line directly on frames)
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
|
||||
final = composite_video(
|
||||
clip_path,
|
||||
progress_data,
|
||||
out_path,
|
||||
actual_fps,
|
||||
frame_h,
|
||||
frame_w,
|
||||
task_name=ep_meta.get("task_name", ""),
|
||||
)
|
||||
clip_path.unlink(missing_ok=True)
|
||||
print(f"\n✓ Done: {final}")
|
||||
return final
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
for cfg in DATASETS:
|
||||
try:
|
||||
out = process_dataset(cfg["repo_id"], cfg["episode"])
|
||||
if out:
|
||||
results.append(out)
|
||||
except Exception as e:
|
||||
print(f"\nERROR processing {cfg['repo_id']}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Output files:")
|
||||
for r in results:
|
||||
print(f" {r}")
|
||||
@@ -0,0 +1,496 @@
|
||||
"""
|
||||
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
|
||||
Downloads joint position data (no videos) from HuggingFace, computes forward
|
||||
kinematics per episode, clusters trajectories with K-means, and renders
|
||||
2D projections comparing dataset coverage and multimodality.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
N_CLUSTERS = 10
|
||||
WAYPOINTS = 50
|
||||
SEED = 42
|
||||
DPI = 180
|
||||
|
||||
CLUSTER_COLORS = [
|
||||
"#e6194b",
|
||||
"#3cb44b",
|
||||
"#4363d8",
|
||||
"#f58231",
|
||||
"#911eb4",
|
||||
"#42d4f4",
|
||||
"#f032e6",
|
||||
"#bfef45",
|
||||
"#fabed4",
|
||||
"#dcbeff",
|
||||
"#9a6324",
|
||||
"#fffac8",
|
||||
"#800000",
|
||||
"#aaffc3",
|
||||
"#808000",
|
||||
"#ffd8b1",
|
||||
"#000075",
|
||||
"#a9a9a9",
|
||||
]
|
||||
|
||||
# FK chains extracted from OpenArm bimanual URDF.
|
||||
# Each entry: (rpy, xyz, revolute_axis_or_None).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
"""Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
"""Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
"""Auto-detect servo ticks / degrees / radians and convert to radians."""
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
"""Try to find left/right joint indices from info.json feature names."""
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet only) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "data/**"],
|
||||
ignore_patterns=["*.mp4", "videos/**"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
|
||||
"""Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
|
||||
f = traj.shape[0]
|
||||
if f == n_waypoints:
|
||||
return traj
|
||||
old_t = np.linspace(0, 1, f)
|
||||
new_t = np.linspace(0, 1, n_waypoints)
|
||||
return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
|
||||
|
||||
|
||||
def load_episode_trajectories(local: Path) -> list[dict]:
|
||||
"""
|
||||
Load per-episode joint data, compute FK, return list of trajectory dicts.
|
||||
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
|
||||
Uses all episodes in the dataset for a fair comparison.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
print(f" Total frames: {len(df):,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
|
||||
first = df[state_col].iloc[0]
|
||||
if not hasattr(first, "__len__"):
|
||||
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
|
||||
|
||||
state = np.stack(df[state_col].values).astype(np.float64)
|
||||
n_dim = state.shape[1]
|
||||
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
episode_ids = df[ep_col].values
|
||||
unique_eps = np.unique(episode_ids)
|
||||
print(f" Episodes: {len(unique_eps):,}")
|
||||
|
||||
left_raw = state[:, left_idx]
|
||||
right_raw = state[:, right_idx]
|
||||
left_all = _detect_and_convert(left_raw)
|
||||
right_all = _detect_and_convert(right_raw)
|
||||
|
||||
print(" Computing FK per episode …")
|
||||
trajectories = []
|
||||
for ep_id in unique_eps:
|
||||
mask = episode_ids == ep_id
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
|
||||
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
|
||||
if len(left_tcp) < 3:
|
||||
continue
|
||||
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
|
||||
|
||||
print(f" Valid trajectories: {len(trajectories):,}")
|
||||
return trajectories
|
||||
|
||||
|
||||
# ── Clustering ──────────────────────────────────────────
|
||||
|
||||
|
||||
def cluster_trajectories(
|
||||
trajectories: list[dict], n_clusters: int, n_waypoints: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
K-means on resampled trajectory features.
|
||||
Combines left+right TCP into a single feature vector per episode.
|
||||
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
|
||||
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
|
||||
"""
|
||||
feat_vecs = []
|
||||
for t in trajectories:
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
|
||||
feat_matrix = np.array(feat_vecs)
|
||||
|
||||
k = min(n_clusters, len(feat_vecs))
|
||||
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
|
||||
labels = km.fit_predict(feat_matrix)
|
||||
|
||||
centroids_flat = km.cluster_centers_
|
||||
centroid_trajs = np.zeros((k, n_waypoints, 6))
|
||||
for ci in range(k):
|
||||
left_flat = centroids_flat[ci, : n_waypoints * 3]
|
||||
right_flat = centroids_flat[ci, n_waypoints * 3 :]
|
||||
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
|
||||
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
|
||||
|
||||
# Mean per-waypoint distance to centroid (in metres) for each cluster
|
||||
spread = np.zeros(k)
|
||||
for ci in range(k):
|
||||
members = np.where(labels == ci)[0]
|
||||
if len(members) == 0:
|
||||
continue
|
||||
centroid_left = centroid_trajs[ci, :, :3]
|
||||
centroid_right = centroid_trajs[ci, :, 3:]
|
||||
dists = []
|
||||
for mi in members:
|
||||
t = trajectories[mi]
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
|
||||
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
|
||||
dists.append((d_left + d_right) / 2)
|
||||
spread[ci] = np.mean(dists)
|
||||
|
||||
return labels, centroid_trajs, spread
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
PROJ_VIEWS = [
|
||||
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
|
||||
("XY (top)", 0, 1, "X (m)", "Y (m)"),
|
||||
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
|
||||
]
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
2-row × 3-col grid per dataset (3 projections × 2 datasets).
|
||||
Trajectory lines colored by cluster, centroid trajectories drawn thick.
|
||||
"""
|
||||
n_ds = len(results)
|
||||
n_proj = len(PROJ_VIEWS)
|
||||
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[np.newaxis, :]
|
||||
|
||||
for row, r in enumerate(results):
|
||||
trajectories = r["trajectories"]
|
||||
labels = r["labels"]
|
||||
centroids = r["centroids"]
|
||||
k = centroids.shape[0]
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=k)
|
||||
size_order = np.argsort(-cluster_sizes)
|
||||
pcts = cluster_sizes / len(labels) * 100
|
||||
spread = r["spread"]
|
||||
|
||||
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
|
||||
ax = axes[row, col]
|
||||
ax.set_facecolor("#0d1117")
|
||||
|
||||
for ti, traj in enumerate(trajectories):
|
||||
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
|
||||
for tcp_key in ("left_tcp", "right_tcp"):
|
||||
pts = traj[tcp_key]
|
||||
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
|
||||
|
||||
for ci in range(k):
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
left_c = centroids[ci, :, :3]
|
||||
right_c = centroids[ci, :, 3:]
|
||||
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
|
||||
for c_pts in (left_c, right_c):
|
||||
ax.plot(
|
||||
c_pts[:, dim_a],
|
||||
c_pts[:, dim_b],
|
||||
color=color,
|
||||
linewidth=lw,
|
||||
alpha=0.95,
|
||||
zorder=10,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[0, dim_a],
|
||||
c_pts[0, dim_b],
|
||||
"o",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[-1, dim_a],
|
||||
c_pts[-1, dim_b],
|
||||
"s",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
|
||||
ax.set_xlabel(xlabel, color="#888", fontsize=9)
|
||||
ax.set_ylabel(ylabel, color="#888", fontsize=9)
|
||||
ax.tick_params(colors="#555", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
ax.set_aspect("equal")
|
||||
|
||||
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
|
||||
if col == 0:
|
||||
ax.set_title(
|
||||
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
|
||||
f"avg spread {mean_spread_cm:.1f}cm)",
|
||||
color="white",
|
||||
fontsize=11,
|
||||
pad=10,
|
||||
)
|
||||
else:
|
||||
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
|
||||
|
||||
# Cluster size + spread legend on the rightmost panel
|
||||
legend_ax = axes[row, -1]
|
||||
for ci in size_order:
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
spread_cm = spread[ci] * 100
|
||||
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
|
||||
legend_ax.plot([], [], color=color, linewidth=3, label=label)
|
||||
legend_ax.legend(
|
||||
loc="upper right",
|
||||
fontsize=7,
|
||||
frameon=True,
|
||||
facecolor="#1a1a2e",
|
||||
edgecolor="#333",
|
||||
labelcolor="white",
|
||||
handlelength=1.5,
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
"End-Effector Trajectory Clusters (FK · K-means)",
|
||||
color="white",
|
||||
fontsize=16,
|
||||
y=0.98,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id)
|
||||
trajectories = load_episode_trajectories(local)
|
||||
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
|
||||
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
|
||||
for ci in np.argsort(-cluster_sizes):
|
||||
print(
|
||||
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
|
||||
f"spread ±{spread[ci] * 100:.1f}cm"
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"trajectories": trajectories,
|
||||
"labels": labels,
|
||||
"centroids": centroids,
|
||||
"spread": spread,
|
||||
"n_episodes": len(trajectories),
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
|
||||
render(results, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
|
||||
@@ -16,15 +16,13 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
@@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
@@ -38,6 +38,7 @@ from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -18,7 +18,7 @@ import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
@@ -27,6 +27,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
@@ -31,6 +31,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@@ -22,7 +22,8 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||
|
||||
DROID_SHARDS = 2048
|
||||
|
||||
@@ -26,7 +26,7 @@ from huggingface_hub import HfApi
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -155,7 +155,7 @@ class UploadDataset(PipelineStep):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
@@ -113,8 +113,9 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
|
||||
@@ -82,7 +82,7 @@ from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
@@ -16,15 +16,13 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
)
|
||||
@@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -16,11 +16,11 @@
|
||||
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.feature_utils import combine_feature_dicts
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
@@ -35,6 +35,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
@@ -19,7 +19,7 @@ import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
@@ -28,6 +28,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
@@ -30,6 +30,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@@ -19,8 +19,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -20,9 +20,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -5,8 +5,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -5,8 +5,9 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
@@ -6,8 +6,8 @@ from queue import Empty, Full
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
|
||||
+4
-5
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.5.0"
|
||||
version = "0.5.1"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -76,7 +76,7 @@ dependencies = [
|
||||
"torchvision>=0.21.0,<0.26.0",
|
||||
|
||||
"einops>=0.8.0,<0.9.0",
|
||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
"av>=15.0.0,<16.0.0",
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
"pynput>=1.7.8,<1.9.0",
|
||||
@@ -119,14 +119,13 @@ gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"unitree-sdk2==1.0.1",
|
||||
# "unitree-sdk2==1.0.1",
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"onnx>=1.16.0,<2.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"lerobot[matplotlib-dep]",
|
||||
"lerobot[pygame-dep]",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ( # noqa: F401
|
||||
|
||||
@@ -39,15 +39,13 @@ import grpc
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
from .configs import PolicyServerConfig
|
||||
from .constants import SUPPORTED_POLICIES
|
||||
|
||||
@@ -36,6 +36,16 @@ class DatasetConfig:
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}"
|
||||
)
|
||||
if len(self.episodes) != len(set(self.episodes)):
|
||||
duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1})
|
||||
raise ValueError(f"Episode indices contain duplicates: {duplicates}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandBConfig:
|
||||
@@ -47,6 +57,7 @@ class WandBConfig:
|
||||
notes: str | None = None
|
||||
run_id: str | None = None
|
||||
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
|
||||
add_tags: bool = True # If True, save configuration as tags in the WandB run.
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -30,8 +30,8 @@ from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
logger = getLogger(__name__)
|
||||
|
||||
@@ -51,7 +51,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
# AND for the evaluation environments.
|
||||
seed: int | None = 1000
|
||||
# Set to True to use deterministic cuDNN algorithms for reproducibility.
|
||||
# This disables cudnn.benchmark and may reduce training speed by ~10-20%.
|
||||
# This disables cudnn.benchmark and may reduce training speed by ~10-20 percent.
|
||||
cudnn_deterministic: bool = False
|
||||
# Number of workers for the dataloader.
|
||||
num_workers: int = 4
|
||||
|
||||
@@ -746,7 +746,8 @@ def save_annotations_to_dataset(
|
||||
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
|
||||
):
|
||||
"""Save annotations to LeRobot dataset parquet format."""
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
|
||||
|
||||
episodes_dataset = load_episodes(dataset_path)
|
||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||
@@ -840,7 +841,7 @@ def generate_auto_sparse_annotations(
|
||||
|
||||
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
|
||||
"""Load annotations from LeRobot dataset parquet files."""
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.datasets.io_utils import load_episodes
|
||||
|
||||
episodes_dataset = load_episodes(dataset_path)
|
||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||
|
||||
@@ -24,7 +24,16 @@ import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -32,14 +41,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import packaging.version
|
||||
|
||||
V30_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v3.0 which is not backward compatible with v2.1.
|
||||
Please, update your dataset to the new format using this command:
|
||||
```
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you already have a converted version uploaded to the hub, then this error might be because of
|
||||
an older version in your local cache. Consider deleting the cached version and retrying.
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
"""
|
||||
|
||||
|
||||
class CompatibilityError(Exception): ...
|
||||
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
if version.major == 2 and version.minor == 1:
|
||||
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ForwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
super().__init__(message)
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
from lerobot.datasets.io_utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
|
||||
@@ -0,0 +1,517 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_file_size_in_mb,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
INFO_PATH,
|
||||
check_version_compatibility,
|
||||
flatten_dict,
|
||||
get_safe_version,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return {key: ft["names"] for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
"""
|
||||
if task in self.tasks.index:
|
||||
return int(self.tasks.loc[task].task_index)
|
||||
else:
|
||||
return None
|
||||
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
if len(set(tasks)) != len(tasks):
|
||||
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
||||
self.tasks.loc[task] = task_idx
|
||||
|
||||
if len(new_tasks) > 0:
|
||||
# Update on disk
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
||||
|
||||
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
||||
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
||||
at once instead of one row at a time.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert to list format for each value
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.episodes is not None and len(self.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
||||
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
||||
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
||||
episode_dict["dataset_from_index"] = [latest_num_frames]
|
||||
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
else:
|
||||
episode_dict["dataset_from_index"] = [0]
|
||||
episode_dict["dataset_to_index"] = [num_frames]
|
||||
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
else:
|
||||
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
||||
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
||||
latest_num_frames = self.latest_episode["episode_index"][0]
|
||||
|
||||
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, flush buffer and prepare new parquet file
|
||||
self._flush_metadata_buffer()
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
self._close_writer()
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
# Add to buffer
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
episode_dict.update(episode_metadata)
|
||||
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
This allows users to customize storage organization without modifying the constructor.
|
||||
These settings control how episodes are chunked and how large files can grow before
|
||||
creating new ones.
|
||||
|
||||
Args:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info["chunks_size"] = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
def get_chunk_settings(self) -> dict[str, int]:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
f"{self.__class__.__name__}({{\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Total episodes: '{self.total_episodes}',\n"
|
||||
f" Total frames: '{self.total_frames}',\n"
|
||||
f" Features: '{feature_keys}',\n"
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
robot_type: str | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
metadata_buffer_size: int = 10,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION,
|
||||
fps,
|
||||
features,
|
||||
use_videos,
|
||||
robot_type,
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError(
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
@@ -38,19 +38,22 @@ from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
@@ -915,7 +918,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
|
||||
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
|
||||
"""
|
||||
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import embed_images
|
||||
|
||||
hf_features = get_hf_features_from_features(meta.features)
|
||||
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
|
||||
|
||||
@@ -20,11 +20,9 @@ import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
@@ -0,0 +1,552 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
|
||||
|
||||
Args:
|
||||
features (dict): A LeRobot-style feature dictionary.
|
||||
|
||||
Returns:
|
||||
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If a feature has an unsupported shape.
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
"""Validate that feature names do not contain invalid characters.
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot features dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If any feature name contains '/'.
|
||||
"""
|
||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||
if invalid_features:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||
|
||||
|
||||
def hw_to_dataset_features(
|
||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
|
||||
|
||||
This function takes a dictionary describing hardware outputs (like joint states
|
||||
or camera image shapes) and formats it into the standard LeRobot feature
|
||||
specification.
|
||||
|
||||
Args:
|
||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||
joints) or shape (tuple for images).
|
||||
prefix (str): The prefix to add to the feature keys (e.g., "observation"
|
||||
or "action").
|
||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||
|
||||
Returns:
|
||||
dict: A LeRobot features dictionary.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {
|
||||
key: ftype
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
if joint_fts and prefix == OBS_STR:
|
||||
features[f"{prefix}.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
|
||||
def build_dataset_frame(
|
||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Construct a single data frame from raw values based on dataset features.
|
||||
|
||||
A "frame" is a dictionary containing all the data for a single timestep,
|
||||
formatted as numpy arrays according to the feature specification.
|
||||
|
||||
Args:
|
||||
ds_features (dict): The LeRobot dataset features dictionary.
|
||||
values (dict): A dictionary of raw values from the hardware/environment.
|
||||
prefix (str): The prefix to filter features by (e.g., "observation"
|
||||
or "action").
|
||||
|
||||
Returns:
|
||||
dict: A dictionary representing a single frame of data.
|
||||
"""
|
||||
frame = {}
|
||||
for key, ft in ds_features.items():
|
||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||
continue
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert dataset features to policy features.
|
||||
|
||||
This function transforms the dataset's feature specification into a format
|
||||
that a policy can use, classifying features by type (e.g., visual, state,
|
||||
action) and ensuring correct shapes (e.g., channel-first for images).
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot dataset features dictionary.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If an image feature does not have a 3D shape.
|
||||
"""
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith(ACTION):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||
"""Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
|
||||
|
||||
Args:
|
||||
*dicts: A variable number of LeRobot feature dictionaries to merge.
|
||||
|
||||
Returns:
|
||||
dict: A single merged feature dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a dtype mismatch for a feature being merged.
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
for key, value in d.items():
|
||||
if not isinstance(value, dict):
|
||||
out[key] = value
|
||||
continue
|
||||
|
||||
dtype = value.get("dtype")
|
||||
shape = value.get("shape")
|
||||
is_vector = (
|
||||
dtype not in ("image", "video", "string")
|
||||
and isinstance(shape, tuple)
|
||||
and len(shape) == 1
|
||||
and "names" in value
|
||||
)
|
||||
|
||||
if is_vector:
|
||||
# Initialize or retrieve the accumulating dict for this feature key
|
||||
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
||||
# Ensure consistent data types across merged entries
|
||||
if "dtype" in target and dtype != target["dtype"]:
|
||||
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
||||
|
||||
# Merge feature names: append only new ones to preserve order without duplicates
|
||||
seen = set(target["names"])
|
||||
for n in value["names"]:
|
||||
if n not in seen:
|
||||
target["names"].append(n)
|
||||
seen.add(n)
|
||||
# Recompute the shape to reflect the updated number of features
|
||||
target["shape"] = (len(target["names"]),)
|
||||
else:
|
||||
# For images/videos and non-1D entries: override with the latest definition
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
fps (int): The frames per second of the data.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
"""
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
|
||||
|
||||
This ensures that adding these delta timestamps to any existing timestamp in
|
||||
the dataset will result in a value that aligns with the dataset's frame rate.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary where values are lists of time
|
||||
deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
tolerance_s (float): The allowed tolerance in seconds.
|
||||
raise_value_error (bool): If True, raises an error on failure.
|
||||
|
||||
Returns:
|
||||
bool: True if all deltas are valid, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""
|
||||
The following delta_timestamps are found outside of tolerance range.
|
||||
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
||||
their values accordingly.
|
||||
\n{pformat(outside_tolerance)}
|
||||
"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
"""Convert delta timestamps in seconds to delta indices in frames.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary of time deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of frame delta indices.
|
||||
"""
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
|
||||
return delta_indices
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict) -> None:
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
# task is a special required field that's not part of regular features
|
||||
if "task" not in actual_features:
|
||||
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
|
||||
|
||||
# Remove task from actual_features for regular feature validation
|
||||
actual_features_for_validation = actual_features - {"task"}
|
||||
|
||||
error_message = validate_features_presence(actual_features_for_validation, expected_features)
|
||||
|
||||
common_features = actual_features_for_validation & expected_features
|
||||
for name in common_features:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
|
||||
"""Check for missing or extra features in a frame.
|
||||
|
||||
Args:
|
||||
actual_features (set[str]): The set of feature names present in the frame.
|
||||
expected_features (set[str]): The set of feature names expected in the frame.
|
||||
|
||||
Returns:
|
||||
str: An error message string if there's a mismatch, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - expected_features
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(
|
||||
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
|
||||
) -> str:
|
||||
"""Validate the dtype and shape of a single feature's value.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
feature (dict): The feature specification from the LeRobot features dictionary.
|
||||
value: The value of the feature to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the feature dtype is not supported for validation.
|
||||
"""
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
) -> str:
|
||||
"""Validate a feature that is expected to be a numpy array.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_dtype (str): The expected numpy dtype as a string.
|
||||
expected_shape (list[int]): The expected shape.
|
||||
value (np.ndarray): The numpy array to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(
|
||||
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
|
||||
) -> str:
|
||||
"""Validate a feature that is expected to be an image or video frame.
|
||||
|
||||
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str) -> str:
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value (str): The value to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
Ensures the buffer has the required keys, contains at least one frame, and
|
||||
has features consistent with the dataset's specification.
|
||||
|
||||
Args:
|
||||
episode_buffer (dict): The buffer containing data for a single episode.
|
||||
total_episodes (int): The current total number of episodes in the dataset.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the buffer is invalid.
|
||||
NotImplementedError: If the episode index is manually set and doesn't match.
|
||||
"""
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
@@ -22,6 +23,8 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -31,7 +34,7 @@ def safe_stop_image_writer(func):
|
||||
dataset = kwargs.get("dataset")
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
logger.warning("Waiting for image writer to terminate...")
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
@@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
PIL.Image.Image object.
|
||||
|
||||
Side Effects:
|
||||
Prints an error message to the console if the image writing process
|
||||
fails for any reason.
|
||||
Logs an error message if the image writing process fails for any reason.
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
@@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
def worker_thread_loop(queue: queue.Queue):
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
flatten_dict,
|
||||
serialize_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.utils.utils import SuppressProgressBars
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
total_uncompressed_size = 0
|
||||
for row_group in range(metadata.num_row_groups):
|
||||
rg_metadata = metadata.row_group(row_group)
|
||||
for column in range(rg_metadata.num_columns):
|
||||
col_metadata = rg_metadata.column(column)
|
||||
total_uncompressed_size += col_metadata.total_uncompressed_size
|
||||
return total_uncompressed_size / (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes // (1024**2)
|
||||
|
||||
|
||||
def load_nested_dataset(
|
||||
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
|
||||
) -> Dataset:
|
||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||
Concatenate all pyarrow references to return HF Dataset format
|
||||
|
||||
Args:
|
||||
pq_dir: Directory containing parquet files
|
||||
features: Optional features schema to ensure consistent loading of complex types like images
|
||||
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
|
||||
"""
|
||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
return metadata.num_rows
|
||||
|
||||
|
||||
def get_file_size_in_mb(file_path: Path) -> float:
|
||||
"""Get file size on disk in megabytes.
|
||||
|
||||
Args:
|
||||
file_path (Path): Path to the file.
|
||||
"""
|
||||
file_size_bytes = file_path.stat().st_size
|
||||
return file_size_bytes / (1024**2)
|
||||
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Embed image bytes into the dataset table before saving to Parquet.
|
||||
|
||||
This function prepares a Hugging Face dataset for serialization by converting
|
||||
image objects into an embedded format that can be stored in Arrow/Parquet.
|
||||
|
||||
Args:
|
||||
dataset (datasets.Dataset): The input dataset, possibly containing image features.
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: The dataset with images embedded in the table storage.
|
||||
"""
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
"""Load data from a JSON file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
Any: The data loaded from the JSON file.
|
||||
"""
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
"""Write data to a JSON file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to write.
|
||||
fpath (Path): The path to the output JSON file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path) -> None:
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information dictionary.
|
||||
"""
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||
"""Serialize and write dataset statistics to their standard file path.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None:
|
||||
"""Load dataset statistics and cast numerical values to numpy arrays.
|
||||
|
||||
Returns None if the stats file doesn't exist.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
A dictionary of statistics or None if the file is not found.
|
||||
"""
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
|
||||
path = local_dir / DEFAULT_TASKS_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tasks.to_parquet(path)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
tasks.index.name = "task"
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures.
|
||||
|
||||
Args:
|
||||
episodes: HuggingFace Dataset containing episode metadata
|
||||
local_dir: Root directory where the dataset will be stored
|
||||
"""
|
||||
episode_size_mb = get_hf_dataset_size_in_mb(episodes)
|
||||
if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError(
|
||||
f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. "
|
||||
f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. "
|
||||
"This function only supports single-file episode metadata. "
|
||||
)
|
||||
|
||||
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
|
||||
# Select episode features/columns containing references to episode data and videos
|
||||
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
|
||||
# This is to speedup access to these data, instead of having to load episode stats.
|
||||
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
|
||||
return episodes
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Load an image from a file into a numpy array.
|
||||
|
||||
Args:
|
||||
fpath (str | Path): Path to the image file.
|
||||
dtype (np.dtype): The desired data type of the output array. If floating,
|
||||
pixels are scaled to [0, 1].
|
||||
channel_first (bool): If True, converts the image to (C, H, W) format.
|
||||
Otherwise, it remains in (H, W, C) format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image as a numpy array.
|
||||
"""
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
|
||||
|
||||
This function is used to convert an item from a streaming dataset to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
item (dict): Dictionary of items from a dataset.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
@@ -23,526 +23,52 @@ from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
INFO_PATH,
|
||||
_validate_feature_names,
|
||||
from lerobot.datasets.compute_stats import compute_episode_stats
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import (
|
||||
check_delta_timestamps,
|
||||
check_version_compatibility,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
flatten_dict,
|
||||
get_delta_indices,
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_safe_version,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
)
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.io_utils import (
|
||||
embed_images,
|
||||
get_file_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
load_episodes,
|
||||
load_nested_dataset,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
create_lerobot_dataset_card,
|
||||
get_safe_version,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
# Extract value and serialize numpy arrays
|
||||
# because PyArrow's from_pydict function doesn't support numpy arrays
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
raise IndexError(
|
||||
f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
|
||||
)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return {key: ft["names"] for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
"""
|
||||
if task in self.tasks.index:
|
||||
return int(self.tasks.loc[task].task_index)
|
||||
else:
|
||||
return None
|
||||
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
if len(set(tasks)) != len(tasks):
|
||||
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
||||
self.tasks.loc[task] = task_idx
|
||||
|
||||
if len(new_tasks) > 0:
|
||||
# Update on disk
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Buffer episode metadata and write to parquet in batches for efficiency.
|
||||
|
||||
This function accumulates episode metadata in a buffer and flushes it when the buffer
|
||||
reaches the configured size. This reduces I/O overhead by writing multiple episodes
|
||||
at once instead of one row at a time.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert to list format for each value
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
if self.latest_episode is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self.episodes is not None and len(self.episodes) > 0:
|
||||
# It means we are resuming recording, so we need to load the latest episode
|
||||
# Update the indices to avoid overwriting the latest episode
|
||||
chunk_idx = self.episodes[-1]["meta/episodes/chunk_index"]
|
||||
file_idx = self.episodes[-1]["meta/episodes/file_index"]
|
||||
latest_num_frames = self.episodes[-1]["dataset_to_index"]
|
||||
episode_dict["dataset_from_index"] = [latest_num_frames]
|
||||
episode_dict["dataset_to_index"] = [latest_num_frames + num_frames]
|
||||
|
||||
# When resuming, move to the next file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
else:
|
||||
episode_dict["dataset_from_index"] = [0]
|
||||
episode_dict["dataset_to_index"] = [num_frames]
|
||||
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
else:
|
||||
chunk_idx = self.latest_episode["meta/episodes/chunk_index"][0]
|
||||
file_idx = self.latest_episode["meta/episodes/file_index"][0]
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
latest_size_in_mb = get_file_size_in_mb(Path(latest_path))
|
||||
latest_num_frames = self.latest_episode["episode_index"][0]
|
||||
|
||||
av_size_per_frame = latest_size_in_mb / latest_num_frames if latest_num_frames > 0 else 0.0
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * num_frames >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, flush buffer and prepare new parquet file
|
||||
self._flush_metadata_buffer()
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
self._close_writer()
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
episode_dict["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
episode_dict["meta/episodes/file_index"] = [file_idx]
|
||||
episode_dict["dataset_from_index"] = [self.latest_episode["dataset_to_index"][0]]
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
# Add to buffer
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_index: int,
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
episode_dict.update(episode_metadata)
|
||||
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
This allows users to customize storage organization without modifying the constructor.
|
||||
These settings control how episodes are chunked and how large files can grow before
|
||||
creating new ones.
|
||||
|
||||
Args:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info["chunks_size"] = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
def get_chunk_settings(self) -> dict[str, int]:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
f"{self.__class__.__name__}({{\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Total episodes: '{self.total_episodes}',\n"
|
||||
f" Total frames: '{self.total_frames}',\n"
|
||||
f" Features: '{feature_keys}',\n"
|
||||
"})',\n"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
robot_type: str | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
metadata_buffer_size: int = 10,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION,
|
||||
fps,
|
||||
features,
|
||||
use_videos,
|
||||
robot_type,
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
@@ -596,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
the dataset from that address and load it, pending your dataset is compliant with
|
||||
codebase_version v3.0. If your dataset has been created before this new format, you will be
|
||||
prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at
|
||||
lerobot/datasets/v30/convert_dataset_v21_to_v30.py.
|
||||
lerobot/scripts/convert_dataset_v21_to_v30.py.
|
||||
|
||||
|
||||
2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty
|
||||
@@ -1326,7 +852,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
logger.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
@@ -1365,7 +891,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if end_episode is None:
|
||||
end_episode = self.num_episodes
|
||||
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
@@ -1375,7 +901,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logging.info(f"Encoding videos for episode {ep_idx}")
|
||||
logger.info(f"Encoding videos for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
@@ -1605,7 +1131,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
|
||||
)
|
||||
|
||||
@@ -1683,7 +1209,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
obj.episodes = None
|
||||
@@ -1717,184 +1242,3 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._streaming_encoder = None
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
||||
structure of `LeRobotDataset`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
root: str | Path | None = None,
|
||||
episodes: dict | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
|
||||
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
"""Returns True if this dataset loads video frames from mp4 files.
|
||||
|
||||
Returns False if it only loads images from png files.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
return features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
video_frame_keys.append(key)
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return sum(d.num_frames for d in self._datasets)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes."""
|
||||
return sum(d.num_episodes for d in self._datasets)
|
||||
|
||||
@property
|
||||
def tolerance_s(self) -> float:
|
||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
||||
is provided or when loading video frames from mp4 files.
|
||||
"""
|
||||
# 1e-4 to account for possible numerical error
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Number of Samples: {self.num_frames},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.utils
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import VideoFrame
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
||||
structure of `LeRobotDataset`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
root: str | Path | None = None,
|
||||
episodes: dict | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
if extra_keys:
|
||||
logger.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
|
||||
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
"""Returns True if this dataset loads video frames from mp4 files.
|
||||
|
||||
Returns False if it only loads images from png files.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
return features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
video_frame_keys.append(key)
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return sum(d.num_frames for d in self._datasets)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes."""
|
||||
return sum(d.num_episodes for d in self._datasets)
|
||||
|
||||
@property
|
||||
def tolerance_s(self) -> float:
|
||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
||||
is provided or when loading video frames from mp4 files.
|
||||
"""
|
||||
# 1e-4 to account for possible numerical error
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Number of Samples: {self.num_frames},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
@@ -1,382 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""An online buffer for the online training loop in train.py
|
||||
|
||||
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
|
||||
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
|
||||
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
|
||||
supports in-place slicing and mutation which is very handy for a dynamic buffer.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def _make_memmap_safe(**kwargs) -> np.memmap:
|
||||
"""Make a numpy memmap with checks on available disk space first.
|
||||
|
||||
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
|
||||
|
||||
For information on dtypes:
|
||||
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
|
||||
"""
|
||||
if kwargs["mode"].startswith("w"):
|
||||
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
|
||||
stats = os.statvfs(Path(kwargs["filename"]).parent)
|
||||
available_space = stats.f_bavail * stats.f_frsize # bytes
|
||||
if required_space >= available_space * 0.8:
|
||||
raise RuntimeError(
|
||||
f"You're about to take up {required_space} of {available_space} bytes available."
|
||||
)
|
||||
return np.memmap(**kwargs)
|
||||
|
||||
|
||||
class OnlineBuffer(torch.utils.data.Dataset):
|
||||
"""FIFO data buffer for the online training loop in train.py.
|
||||
|
||||
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
|
||||
loop in the same way that a LeRobotDataset would be used.
|
||||
|
||||
The underlying data structure will have data inserted in a circular fashion. Always insert after the
|
||||
last index, and when you reach the end, wrap around to the start.
|
||||
|
||||
The data is stored in a numpy memmap.
|
||||
"""
|
||||
|
||||
NEXT_INDEX_KEY = "_next_index"
|
||||
OCCUPANCY_MASK_KEY = "_occupancy_mask"
|
||||
INDEX_KEY = "index"
|
||||
FRAME_INDEX_KEY = "frame_index"
|
||||
EPISODE_INDEX_KEY = "episode_index"
|
||||
TIMESTAMP_KEY = "timestamp"
|
||||
IS_PAD_POSTFIX = "_is_pad"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
write_dir: str | Path,
|
||||
data_spec: dict[str, Any] | None,
|
||||
buffer_capacity: int | None,
|
||||
fps: float | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
|
||||
):
|
||||
"""
|
||||
The online buffer can be provided from scratch or you can load an existing online buffer by passing
|
||||
a `write_dir` associated with an existing buffer.
|
||||
|
||||
Args:
|
||||
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
|
||||
Note that if the files already exist, they are opened in read-write mode (used for training
|
||||
resumption.)
|
||||
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
|
||||
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
|
||||
but note that "index", "frame_index" and "episode_index" are already accounted for by this
|
||||
class, so you don't need to include them.
|
||||
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
|
||||
system's available disk space when choosing this.
|
||||
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
|
||||
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
|
||||
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
|
||||
converted to dict[str, np.ndarray] for optimization purposes.
|
||||
|
||||
"""
|
||||
self.set_delta_timestamps(delta_timestamps)
|
||||
self._fps = fps
|
||||
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
|
||||
# the requested frames. It is only used when `delta_timestamps` is provided.
|
||||
# minus 1e-4 to account for possible numerical error
|
||||
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
|
||||
self._buffer_capacity = buffer_capacity
|
||||
data_spec = self._make_data_spec(data_spec, buffer_capacity)
|
||||
Path(write_dir).mkdir(parents=True, exist_ok=True)
|
||||
self._data = {}
|
||||
for k, v in data_spec.items():
|
||||
self._data[k] = _make_memmap_safe(
|
||||
filename=Path(write_dir) / k,
|
||||
dtype=v["dtype"] if v is not None else None,
|
||||
mode="r+" if (Path(write_dir) / k).exists() else "w+",
|
||||
shape=tuple(v["shape"]) if v is not None else None,
|
||||
)
|
||||
|
||||
@property
|
||||
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
|
||||
return self._delta_timestamps
|
||||
|
||||
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
|
||||
"""Set delta_timestamps converting the values to numpy arrays.
|
||||
|
||||
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
|
||||
need to be converted into numpy arrays.
|
||||
"""
|
||||
if value is not None:
|
||||
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
|
||||
)
|
||||
preset_keys = {
|
||||
OnlineBuffer.INDEX_KEY,
|
||||
OnlineBuffer.FRAME_INDEX_KEY,
|
||||
OnlineBuffer.EPISODE_INDEX_KEY,
|
||||
OnlineBuffer.TIMESTAMP_KEY,
|
||||
}
|
||||
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
|
||||
raise ValueError(
|
||||
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
|
||||
f"The provided data_spec has {intersection}."
|
||||
)
|
||||
complete_data_spec = {
|
||||
# _next_index will be a pointer to the next index that we should start filling from when we add
|
||||
# more data.
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
"""Add new data to the buffer, which could potentially mean shifting old data out.
|
||||
|
||||
The new data should contain all the frames (in order) of any number of episodes. The indices should
|
||||
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
|
||||
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
|
||||
|
||||
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
|
||||
will be done in place!
|
||||
"""
|
||||
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
|
||||
raise ValueError(f"Missing data keys: {missing_keys}")
|
||||
new_data_length = len(data[self.data_keys[0]])
|
||||
if not all(len(data[k]) == new_data_length for k in self.data_keys):
|
||||
raise ValueError("All data items should have the same length")
|
||||
|
||||
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
|
||||
|
||||
# Sanity check to make sure that the new data indices start from 0.
|
||||
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
|
||||
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
|
||||
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
|
||||
for k in self.data_keys:
|
||||
if n_surplus == 0:
|
||||
slc = slice(next_index, next_index + new_data_length)
|
||||
self._data[k][slc] = data[k]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
|
||||
else:
|
||||
self._data[k][next_index:] = data[k][:-n_surplus]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
|
||||
self._data[k][:n_surplus] = data[k][-n_surplus:]
|
||||
if n_surplus == 0:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
|
||||
else:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
|
||||
|
||||
@property
|
||||
def data_keys(self) -> list[str]:
|
||||
keys = set(self._data)
|
||||
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
|
||||
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
|
||||
return sorted(keys)
|
||||
|
||||
@property
|
||||
def fps(self) -> float | None:
|
||||
return self._fps
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
)
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
for k, v in item.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
item_[k] = v
|
||||
elif isinstance(v, np.ndarray):
|
||||
item_[k] = torch.from_numpy(v)
|
||||
else:
|
||||
item_[k] = torch.tensor(v)
|
||||
return item_
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self) or idx < -len(self):
|
||||
raise IndexError
|
||||
|
||||
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
|
||||
|
||||
if self.delta_timestamps is None:
|
||||
return self._item_to_tensors(item)
|
||||
|
||||
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
|
||||
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
|
||||
episode_data_indices = np.where(
|
||||
np.bitwise_and(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
# Get timestamps used as query to retrieve data of previous/future frames.
|
||||
query_ts = current_ts + self.delta_timestamps[data_key]
|
||||
|
||||
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
|
||||
# the episode.
|
||||
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
|
||||
argmin_ = np.argmin(dist, axis=1)
|
||||
min_ = dist[np.arange(dist.shape[0]), argmin_]
|
||||
|
||||
is_pad = min_ > self.tolerance_s
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
)
|
||||
|
||||
# Load frames for this data key.
|
||||
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
|
||||
|
||||
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
|
||||
|
||||
return self._item_to_tensors(item)
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
offline_dataset: LeRobotDataset,
|
||||
offline_drop_n_last_frames: int = 0,
|
||||
online_dataset: OnlineBuffer | None = None,
|
||||
online_sampling_ratio: float | None = None,
|
||||
online_drop_n_last_frames: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the sampling weights for the online training dataloader in train.py.
|
||||
|
||||
Args:
|
||||
offline_dataset: The LeRobotDataset used for offline pre-training.
|
||||
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
|
||||
online_dataset: The OnlineBuffer used in online training.
|
||||
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
|
||||
online dataset is provided, this value must also be provided.
|
||||
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
|
||||
dataset.
|
||||
Returns:
|
||||
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
|
||||
|
||||
Notes to maintainers:
|
||||
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
|
||||
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
|
||||
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
|
||||
is the ability to turn shuffling off.
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
if len(offline_dataset) > 0:
|
||||
offline_data_mask_indices = []
|
||||
for start_index, end_index in zip(
|
||||
offline_dataset.meta.episodes["dataset_from_index"],
|
||||
offline_dataset.meta.episodes["dataset_to_index"],
|
||||
strict=True,
|
||||
):
|
||||
offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
|
||||
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
||||
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(offline_dataset),),
|
||||
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
|
||||
)
|
||||
* offline_data_mask
|
||||
)
|
||||
|
||||
if online_dataset is not None and len(online_dataset) > 0:
|
||||
online_data_mask_indices = []
|
||||
episode_indices = online_dataset.get_data_by_key("episode_index")
|
||||
for episode_idx in torch.unique(episode_indices):
|
||||
where_episode = torch.where(episode_indices == episode_idx)
|
||||
start_index = where_episode[0][0]
|
||||
end_index = where_episode[0][-1] + 1
|
||||
online_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
|
||||
)
|
||||
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
|
||||
online_data_mask[torch.tensor(online_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(online_dataset),),
|
||||
fill_value=online_sampling_ratio / online_data_mask.sum(),
|
||||
)
|
||||
* online_data_mask
|
||||
)
|
||||
|
||||
weights = torch.cat(weights)
|
||||
|
||||
if weights.sum() == 0:
|
||||
weights += 1 / len(weights)
|
||||
else:
|
||||
weights /= weights.sum()
|
||||
|
||||
return weights
|
||||
@@ -17,8 +17,9 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation
|
||||
from lerobot.datasets.feature_utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
@@ -43,11 +44,11 @@ def create_initial_features(
|
||||
return features
|
||||
|
||||
|
||||
# Helper to filter state/action keys based on regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[str]) -> bool:
|
||||
# Helper to filter state/action keys based on compiled regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
return any(pat.search(key) for pat in patterns)
|
||||
|
||||
|
||||
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
|
||||
@@ -88,6 +89,8 @@ def aggregate_pipeline_dataset_features(
|
||||
Returns:
|
||||
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
|
||||
"""
|
||||
compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None
|
||||
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
@@ -119,7 +122,7 @@ def aggregate_pipeline_dataset_features(
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, patterns):
|
||||
if not is_image and not should_keep(key, compiled_patterns):
|
||||
continue
|
||||
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
@@ -13,10 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
def __init__(
|
||||
@@ -39,13 +42,35 @@ class EpisodeAwareSampler:
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -21,16 +22,13 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_delta_indices
|
||||
from lerobot.datasets.io_utils import item_to_torch
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
LookAheadError,
|
||||
LookBackError,
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
get_delta_indices,
|
||||
is_float_in_list,
|
||||
item_to_torch,
|
||||
safe_shard,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
@@ -40,6 +38,164 @@ from lerobot.datasets.video_utils import (
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look back in the history of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LookAheadError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look ahead in the future of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable[T]:
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
This is useful for streaming datasets where you need to access previous and future items
|
||||
but can't load the entire dataset into memory.
|
||||
|
||||
Example:
|
||||
-------
|
||||
```python
|
||||
ds = load_dataset("c4", "en", streaming=True, split="train")
|
||||
rev = Backtrackable(ds, history=3, lookahead=2)
|
||||
|
||||
x0 = next(rev) # forward
|
||||
x1 = next(rev)
|
||||
x2 = next(rev)
|
||||
|
||||
# Look ahead
|
||||
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
||||
x4_peek = rev.peek_ahead(2) # two items ahead
|
||||
|
||||
# Look back
|
||||
x1_again = rev.peek_back(1) # previous item without moving cursor
|
||||
x0_again = rev.peek_back(2) # two items back
|
||||
|
||||
# Move backward
|
||||
x1_back = rev.prev() # back one step
|
||||
next(rev) # returns x2, continues forward from where we were
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
||||
|
||||
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
||||
if history < 1:
|
||||
raise ValueError("history must be >= 1")
|
||||
if lookahead <= 0:
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
def __iter__(self) -> "Backtrackable[T]":
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
# If we've stepped back, consume from back buffer first
|
||||
if self._cursor < 0: # -1 means "last item", etc.
|
||||
self._cursor += 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
# If we have items in the ahead buffer, use them first
|
||||
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
||||
|
||||
# Add current item to back buffer and reset cursor
|
||||
self._back_buf.append(item)
|
||||
self._cursor = 0
|
||||
return item
|
||||
|
||||
def prev(self) -> T:
|
||||
"""
|
||||
Step one item back in history and return it.
|
||||
Raises IndexError if already at the oldest buffered item.
|
||||
"""
|
||||
if len(self._back_buf) + self._cursor <= 1:
|
||||
raise LookBackError("At start of history")
|
||||
|
||||
self._cursor -= 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
def peek_back(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items back (n=1 == previous item) without moving the cursor.
|
||||
"""
|
||||
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
||||
raise LookBackError("peek_back distance out of range")
|
||||
|
||||
return self._back_buf[self._cursor - (n + 1)]
|
||||
|
||||
def peek_ahead(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
||||
Fills the ahead buffer if necessary.
|
||||
"""
|
||||
if n < 1:
|
||||
raise LookAheadError("peek_ahead distance must be 1 or more")
|
||||
elif n > self._lookahead:
|
||||
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
||||
|
||||
# Fill ahead buffer if we don't have enough items
|
||||
while len(self._ahead_buf) < n:
|
||||
try:
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
|
||||
except StopIteration as err:
|
||||
raise LookAheadError("peek_ahead: not enough items in source") from err
|
||||
|
||||
return self._ahead_buf[n - 1]
|
||||
|
||||
def history(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the buffered history (most recent last).
|
||||
The list length ≤ `history` argument passed at construction.
|
||||
"""
|
||||
if self._cursor == 0:
|
||||
return list(self._back_buf)
|
||||
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
"""
|
||||
return steps <= len(self._back_buf) + self._cursor
|
||||
|
||||
def can_peek_ahead(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can peek ahead `steps` items.
|
||||
This may involve trying to fill the ahead buffer.
|
||||
"""
|
||||
if self._lookahead > 0 and steps > self._lookahead:
|
||||
return False
|
||||
|
||||
# Try to fill ahead buffer to check if we can peek that far
|
||||
try:
|
||||
while len(self._ahead_buf) < steps:
|
||||
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
||||
return False
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
return True
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
|
||||
|
||||
+42
-995
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,8 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
@@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
pass # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
@@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str:
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logging.info(f"Using video codec: {vcodec}")
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logging.info(f"Auto-selected video codec: {encoder}")
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
@@ -118,7 +120,7 @@ def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
@@ -208,7 +210,7 @@ def decode_video_frames_torchvision(
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
@@ -244,7 +246,7 @@ def decode_video_frames_torchvision(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
@@ -348,7 +350,7 @@ def decode_video_frames_torchcodec(
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
logger.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
@@ -374,7 +376,7 @@ def decode_video_frames_torchcodec(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
@@ -408,14 +410,14 @@ def encode_video_frames(
|
||||
imgs_dir = Path(imgs_dir)
|
||||
|
||||
if video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
return
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
@@ -508,7 +510,7 @@ def concatenate_video_files(
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
logger.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
@@ -819,7 +821,7 @@ class StreamingVideoEncoder:
|
||||
count = self._dropped_frames[video_key]
|
||||
# Log periodically to avoid spam (1st, then every 10th)
|
||||
if count == 1 or count % 10 == 0:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
||||
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
||||
)
|
||||
@@ -841,7 +843,7 @@ class StreamingVideoEncoder:
|
||||
# Report dropped frames
|
||||
for video_key, count in self._dropped_frames.items():
|
||||
if count > 0:
|
||||
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
|
||||
# Send sentinel to all queues
|
||||
for video_key in self._frame_queues:
|
||||
@@ -851,7 +853,7 @@ class StreamingVideoEncoder:
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=120)
|
||||
if self._threads[video_key].is_alive():
|
||||
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
logger.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
self._stop_events[video_key].set()
|
||||
self._threads[video_key].join(timeout=5)
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
@@ -863,7 +865,7 @@ class StreamingVideoEncoder:
|
||||
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
||||
results[video_key] = (self._video_paths[video_key], data)
|
||||
except queue.Empty:
|
||||
logging.error(f"No result from encoder thread for {video_key}")
|
||||
logger.error(f"No result from encoder thread for {video_key}")
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
|
||||
self._cleanup()
|
||||
@@ -1071,13 +1073,13 @@ class VideoEncodingManager:
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
logger.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
logging.info("Recording stopped. Encoding remaining episodes...")
|
||||
logger.info("Recording stopped. Encoding remaining episodes...")
|
||||
|
||||
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
|
||||
end_ep = self.dataset.num_episodes
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
@@ -1094,7 +1096,7 @@ class VideoEncodingManager:
|
||||
episode_index=interrupted_episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
if img_dir.exists():
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
@@ -1105,8 +1107,8 @@ class VideoEncodingManager:
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -29,7 +29,7 @@ from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
|
||||
@@ -25,7 +25,7 @@ import metaworld.policies as policies
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.processor import RobotObservation
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
# ---- Load configuration data from the external JSON file ----
|
||||
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
|
||||
|
||||
@@ -29,7 +29,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.processor import RobotObservation
|
||||
from lerobot.types import RobotObservation
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.datasets.io_utils import write_json
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict
|
||||
from lerobot.utils.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
|
||||
@@ -23,7 +23,7 @@ import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.datasets.utils import write_json
|
||||
from lerobot.datasets.io_utils import write_json
|
||||
from lerobot.utils.constants import SCHEDULER_STATE
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ import torch
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
@@ -43,13 +43,14 @@ from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
|
||||
@@ -49,7 +49,7 @@ from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
HF_LEROBOT_HOME,
|
||||
|
||||
@@ -36,7 +36,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
|
||||
@@ -37,7 +37,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
|
||||
@@ -48,8 +48,8 @@ from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ from lerobot.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
from lerobot.utils.device_utils import get_safe_dtype
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
@@ -374,9 +374,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
loss_dict["losses_after_forward"] = losses.clone().mean().item()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
|
||||
@@ -23,8 +23,8 @@ from torch import nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame
|
||||
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
|
||||
from lerobot.datasets.feature_utils import build_dataset_frame
|
||||
from lerobot.types import PolicyAction, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
|
||||
|
||||
|
||||
@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
|
||||
self.vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
|
||||
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
self.vqvae_model.discretized = torch.tensor(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
|
||||
self.vqvae_model.discretized.fill_(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
|
||||
print("Finished discretizing action data!")
|
||||
self.vqvae_model.eval()
|
||||
for param in self.vqvae_model.vq_layer.parameters():
|
||||
|
||||
@@ -38,7 +38,7 @@ from lerobot.processor import (
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_PREFIX,
|
||||
|
||||
@@ -14,13 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from .core import (
|
||||
from lerobot.types import (
|
||||
EnvAction,
|
||||
EnvTransition,
|
||||
PolicyAction,
|
||||
@@ -28,6 +22,13 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
|
||||
@@ -25,9 +25,9 @@ from dataclasses import dataclass, field
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, PolicyAction
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition, PolicyAction
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
|
||||
@@ -23,10 +23,9 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.types import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_tensor(
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import PolicyAction, RobotAction
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
|
||||
@@ -14,13 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
|
||||
from .converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from .core import RobotAction, RobotObservation
|
||||
from .pipeline import IdentityProcessorStep, RobotProcessorPipeline
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvAction, EnvTransition, PolicyAction
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
@@ -75,7 +75,7 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Converts numpy action to torch tensor if action exists, otherwise passes through."""
|
||||
from .core import TransitionKey
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
@@ -30,7 +30,8 @@ from lerobot.teleoperators.utils import TeleopEvents
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
InfoProcessorStep,
|
||||
|
||||
@@ -26,10 +26,10 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation
|
||||
|
||||
|
||||
|
||||
@@ -46,10 +46,10 @@ from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
# Generic type variables for pipeline input and output.
|
||||
TInput = TypeVar("TInput")
|
||||
|
||||
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
@@ -40,7 +41,6 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, RobotObservation, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
|
||||
@@ -62,7 +62,6 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
@@ -77,6 +76,8 @@ from lerobot.transport.utils import (
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
@@ -86,7 +87,6 @@ from lerobot.utils.transition import (
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ from lerobot.utils.constants import (
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
@@ -96,7 +97,6 @@ from lerobot.utils.train_utils import (
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True) if self.cfg.add_tags else None,
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -23,7 +23,7 @@ import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -33,21 +33,40 @@ from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Action feature keys
|
||||
ACTION_LINEAR_VEL = "linear.vel"
|
||||
ACTION_ANGULAR_VEL = "angular.vel"
|
||||
ACTION_LINEAR_VEL = "linear_velocity"
|
||||
ACTION_ANGULAR_VEL = "angular_velocity"
|
||||
|
||||
# Observation feature keys
|
||||
# Observation feature keys — cameras
|
||||
OBS_FRONT = "front"
|
||||
OBS_REAR = "rear"
|
||||
OBS_LINEAR_VEL = "linear.vel"
|
||||
OBS_BATTERY_LEVEL = "battery.level"
|
||||
OBS_ORIENTATION_DEG = "orientation.deg"
|
||||
OBS_GPS_LATITUDE = "gps.latitude"
|
||||
OBS_GPS_LONGITUDE = "gps.longitude"
|
||||
OBS_GPS_SIGNAL = "gps.signal"
|
||||
OBS_SIGNAL_LEVEL = "signal.level"
|
||||
|
||||
# Observation feature keys — telemetry
|
||||
OBS_SPEED = "speed"
|
||||
OBS_BATTERY_LEVEL = "battery_level"
|
||||
OBS_ORIENTATION = "orientation"
|
||||
OBS_GPS_LATITUDE = "gps_latitude"
|
||||
OBS_GPS_LONGITUDE = "gps_longitude"
|
||||
OBS_GPS_SIGNAL = "gps_signal"
|
||||
OBS_SIGNAL_LEVEL = "signal_level"
|
||||
OBS_VIBRATION = "vibration"
|
||||
OBS_LAMP_STATE = "lamp.state"
|
||||
OBS_LAMP = "lamp"
|
||||
|
||||
# Observation feature keys — IMU sensors
|
||||
OBS_ACCELEROMETER_X = "accelerometer_x"
|
||||
OBS_ACCELEROMETER_Y = "accelerometer_y"
|
||||
OBS_ACCELEROMETER_Z = "accelerometer_z"
|
||||
OBS_GYROSCOPE_X = "gyroscope_x"
|
||||
OBS_GYROSCOPE_Y = "gyroscope_y"
|
||||
OBS_GYROSCOPE_Z = "gyroscope_z"
|
||||
OBS_MAGNETOMETER_X = "magnetometer_filtered_x"
|
||||
OBS_MAGNETOMETER_Y = "magnetometer_filtered_y"
|
||||
OBS_MAGNETOMETER_Z = "magnetometer_filtered_z"
|
||||
|
||||
# Observation feature keys — wheel RPMs
|
||||
OBS_WHEEL_RPM_0 = "wheel_rpm_0"
|
||||
OBS_WHEEL_RPM_1 = "wheel_rpm_1"
|
||||
OBS_WHEEL_RPM_2 = "wheel_rpm_2"
|
||||
OBS_WHEEL_RPM_3 = "wheel_rpm_3"
|
||||
|
||||
|
||||
class EarthRoverMiniPlus(Robot):
|
||||
@@ -154,33 +173,60 @@ class EarthRoverMiniPlus(Robot):
|
||||
dict: Observation features with types/shapes:
|
||||
- front: (480, 640, 3) - Front camera RGB image
|
||||
- rear: (480, 640, 3) - Rear camera RGB image
|
||||
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
|
||||
- battery.level: float - Battery level (0-1, normalized from 0-100)
|
||||
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
|
||||
- gps.latitude: float - GPS latitude coordinate
|
||||
- gps.longitude: float - GPS longitude coordinate
|
||||
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
|
||||
- signal.level: float - Network signal level (0-1, normalized from 0-5)
|
||||
- speed: float - Current speed (raw SDK value)
|
||||
- battery_level: float - Battery level (0-100)
|
||||
- orientation: float - Robot orientation in degrees
|
||||
- gps_latitude: float - GPS latitude coordinate
|
||||
- gps_longitude: float - GPS longitude coordinate
|
||||
- gps_signal: float - GPS signal strength (percentage)
|
||||
- signal_level: float - Network signal level (0-5)
|
||||
- vibration: float - Vibration sensor reading
|
||||
- lamp.state: float - Lamp state (0=off, 1=on)
|
||||
- lamp: float - Lamp state (0=off, 1=on)
|
||||
- accelerometer_x: float - Accelerometer X axis (raw SDK value)
|
||||
- accelerometer_y: float - Accelerometer Y axis (raw SDK value)
|
||||
- accelerometer_z: float - Accelerometer Z axis (raw SDK value)
|
||||
- gyroscope_x: float - Gyroscope X axis (raw SDK value)
|
||||
- gyroscope_y: float - Gyroscope Y axis (raw SDK value)
|
||||
- gyroscope_z: float - Gyroscope Z axis (raw SDK value)
|
||||
- magnetometer_filtered_x: float - Magnetometer X axis (raw SDK value)
|
||||
- magnetometer_filtered_y: float - Magnetometer Y axis (raw SDK value)
|
||||
- magnetometer_filtered_z: float - Magnetometer Z axis (raw SDK value)
|
||||
- wheel_rpm_0: float - Wheel 0 RPM
|
||||
- wheel_rpm_1: float - Wheel 1 RPM
|
||||
- wheel_rpm_2: float - Wheel 2 RPM
|
||||
- wheel_rpm_3: float - Wheel 3 RPM
|
||||
"""
|
||||
return {
|
||||
# Cameras (height, width, channels)
|
||||
OBS_FRONT: (480, 640, 3),
|
||||
OBS_REAR: (480, 640, 3),
|
||||
# Motion state
|
||||
OBS_LINEAR_VEL: float,
|
||||
# Robot state
|
||||
# Telemetry
|
||||
OBS_SPEED: float,
|
||||
OBS_BATTERY_LEVEL: float,
|
||||
OBS_ORIENTATION_DEG: float,
|
||||
# GPS
|
||||
OBS_ORIENTATION: float,
|
||||
OBS_GPS_LATITUDE: float,
|
||||
OBS_GPS_LONGITUDE: float,
|
||||
OBS_GPS_SIGNAL: float,
|
||||
# Sensors
|
||||
OBS_SIGNAL_LEVEL: float,
|
||||
OBS_VIBRATION: float,
|
||||
OBS_LAMP_STATE: float,
|
||||
OBS_LAMP: float,
|
||||
# IMU — accelerometer
|
||||
OBS_ACCELEROMETER_X: float,
|
||||
OBS_ACCELEROMETER_Y: float,
|
||||
OBS_ACCELEROMETER_Z: float,
|
||||
# IMU — gyroscope
|
||||
OBS_GYROSCOPE_X: float,
|
||||
OBS_GYROSCOPE_Y: float,
|
||||
OBS_GYROSCOPE_Z: float,
|
||||
# IMU — magnetometer
|
||||
OBS_MAGNETOMETER_X: float,
|
||||
OBS_MAGNETOMETER_Y: float,
|
||||
OBS_MAGNETOMETER_Z: float,
|
||||
# Wheel RPMs
|
||||
OBS_WHEEL_RPM_0: float,
|
||||
OBS_WHEEL_RPM_1: float,
|
||||
OBS_WHEEL_RPM_2: float,
|
||||
OBS_WHEEL_RPM_3: float,
|
||||
}
|
||||
|
||||
@cached_property
|
||||
@@ -189,8 +235,8 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
Returns:
|
||||
dict: Action features with types:
|
||||
- linear.vel: float - Target linear velocity
|
||||
- angular.vel: float - Target angular velocity
|
||||
- linear_velocity: float - Target linear velocity (-1 to 1)
|
||||
- angular_velocity: float - Target angular velocity (-1 to 1)
|
||||
"""
|
||||
return {
|
||||
ACTION_LINEAR_VEL: float,
|
||||
@@ -201,19 +247,29 @@ class EarthRoverMiniPlus(Robot):
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
|
||||
Frames are decoded from base64 and converted from BGR to RGB format.
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
Sensor arrays (accels, gyros, mags, rpms) each contain entries of
|
||||
[values..., timestamp]; the latest reading from each array is used.
|
||||
|
||||
Returns:
|
||||
RobotObservation: Observation containing:
|
||||
- front: Front camera image (480, 640, 3) in RGB format
|
||||
- rear: Rear camera image (480, 640, 3) in RGB format
|
||||
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
|
||||
- battery.level: Battery level (0-1, normalized from 0-100)
|
||||
- orientation.deg: Robot orientation (0-1, normalized from raw value)
|
||||
- gps.latitude: GPS latitude coordinate
|
||||
- gps.longitude: GPS longitude coordinate
|
||||
- gps.signal: GPS signal strength (0-1, normalized from percentage)
|
||||
- signal.level: Network signal level (0-1, normalized from 0-5)
|
||||
- vibration: Vibration sensor reading
|
||||
- lamp.state: Lamp state (0=off, 1=on)
|
||||
- speed: float - Current speed (raw SDK value)
|
||||
- battery_level: float - Battery level (0-100)
|
||||
- orientation: float - Robot orientation in degrees
|
||||
- gps_latitude: float - GPS latitude coordinate
|
||||
- gps_longitude: float - GPS longitude coordinate
|
||||
- gps_signal: float - GPS signal strength (percentage)
|
||||
- signal_level: float - Network signal level (0-5)
|
||||
- vibration: float - Vibration sensor reading
|
||||
- lamp: float - Lamp state (0=off, 1=on)
|
||||
- accelerometer_x/y/z: float - Accelerometer axes (raw SDK value)
|
||||
- gyroscope_x/y/z: float - Gyroscope axes (raw SDK value)
|
||||
- magnetometer_filtered_x/y/z: float - Magnetometer axes (raw SDK value)
|
||||
- wheel_rpm_0/1/2/3: float - Wheel RPMs
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
@@ -235,22 +291,41 @@ class EarthRoverMiniPlus(Robot):
|
||||
# Get robot state from SDK
|
||||
robot_data = self._get_robot_data()
|
||||
|
||||
# Motion state
|
||||
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
|
||||
# Telemetry
|
||||
observation[OBS_SPEED] = float(robot_data["speed"])
|
||||
observation[OBS_BATTERY_LEVEL] = float(robot_data["battery"])
|
||||
observation[OBS_ORIENTATION] = float(robot_data["orientation"])
|
||||
observation[OBS_GPS_LATITUDE] = float(robot_data["latitude"])
|
||||
observation[OBS_GPS_LONGITUDE] = float(robot_data["longitude"])
|
||||
observation[OBS_GPS_SIGNAL] = float(robot_data["gps_signal"])
|
||||
observation[OBS_SIGNAL_LEVEL] = float(robot_data["signal_level"])
|
||||
observation[OBS_VIBRATION] = float(robot_data["vibration"])
|
||||
observation[OBS_LAMP] = float(robot_data["lamp"])
|
||||
|
||||
# Robot state
|
||||
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
|
||||
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
|
||||
# Accelerometer — latest reading from accels array [x, y, z, ts]
|
||||
accel = self._latest_sensor_reading(robot_data, "accels", n_values=3)
|
||||
observation[OBS_ACCELEROMETER_X] = accel[0]
|
||||
observation[OBS_ACCELEROMETER_Y] = accel[1]
|
||||
observation[OBS_ACCELEROMETER_Z] = accel[2]
|
||||
|
||||
# GPS data
|
||||
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
|
||||
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
|
||||
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
|
||||
# Gyroscope — latest reading from gyros array [x, y, z, ts]
|
||||
gyro = self._latest_sensor_reading(robot_data, "gyros", n_values=3)
|
||||
observation[OBS_GYROSCOPE_X] = gyro[0]
|
||||
observation[OBS_GYROSCOPE_Y] = gyro[1]
|
||||
observation[OBS_GYROSCOPE_Z] = gyro[2]
|
||||
|
||||
# Sensors
|
||||
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
|
||||
observation[OBS_VIBRATION] = robot_data["vibration"]
|
||||
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
|
||||
# Magnetometer — latest reading from mags array [x, y, z, ts]
|
||||
mag = self._latest_sensor_reading(robot_data, "mags", n_values=3)
|
||||
observation[OBS_MAGNETOMETER_X] = mag[0]
|
||||
observation[OBS_MAGNETOMETER_Y] = mag[1]
|
||||
observation[OBS_MAGNETOMETER_Z] = mag[2]
|
||||
|
||||
# Wheel RPMs — latest reading from rpms array [w0, w1, w2, w3, ts]
|
||||
rpm = self._latest_sensor_reading(robot_data, "rpms", n_values=4)
|
||||
observation[OBS_WHEEL_RPM_0] = rpm[0]
|
||||
observation[OBS_WHEEL_RPM_1] = rpm[1]
|
||||
observation[OBS_WHEEL_RPM_2] = rpm[2]
|
||||
observation[OBS_WHEEL_RPM_3] = rpm[3]
|
||||
|
||||
return observation
|
||||
|
||||
@@ -260,11 +335,12 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
Args:
|
||||
action: Action dict with keys:
|
||||
- linear.vel: Target linear velocity (-1 to 1)
|
||||
- angular.vel: Target angular velocity (-1 to 1)
|
||||
- linear_velocity: Target linear velocity (-1 to 1)
|
||||
- angular_velocity: Target angular velocity (-1 to 1)
|
||||
|
||||
Returns:
|
||||
RobotAction: The action that was sent (matches action_features keys)
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
|
||||
@@ -272,18 +348,14 @@ class EarthRoverMiniPlus(Robot):
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
|
||||
|
||||
# Send command to SDK
|
||||
try:
|
||||
self._send_command_to_sdk(linear, angular)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending action: {e}")
|
||||
|
||||
# Return action in format matching action_features
|
||||
return {
|
||||
ACTION_LINEAR_VEL: linear,
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
@@ -394,11 +466,27 @@ class EarthRoverMiniPlus(Robot):
|
||||
logger.error(f"Error decoding image: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _latest_sensor_reading(robot_data: dict, key: str, n_values: int) -> list[float]:
|
||||
"""Extract the latest sensor reading from an SDK sensor array.
|
||||
|
||||
The SDK returns sensor arrays like ``accels``, ``gyros``, ``mags``,
|
||||
``rpms`` where each entry is ``[value_0, ..., value_n, timestamp]``.
|
||||
This helper returns the *n_values* leading floats from the last entry,
|
||||
falling back to zeros when the key is missing or the array is empty.
|
||||
"""
|
||||
readings = robot_data.get(key)
|
||||
if readings and len(readings) > 0:
|
||||
latest = readings[-1]
|
||||
return [float(v) for v in latest[:n_values]]
|
||||
return [0.0] * n_values
|
||||
|
||||
def _get_robot_data(self) -> dict:
|
||||
"""Get robot telemetry data from SDK.
|
||||
|
||||
Returns:
|
||||
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
|
||||
dict: Robot telemetry data including battery, speed, orientation, GPS,
|
||||
and sensor arrays (accels, gyros, mags, rpms):
|
||||
- Current data (if request succeeds)
|
||||
- Cached data (if request fails but cache exists)
|
||||
- Default values (if request fails and no cache exists yet)
|
||||
@@ -420,19 +508,23 @@ class EarthRoverMiniPlus(Robot):
|
||||
# Fallback: use cache or default values
|
||||
if self._last_robot_data is not None:
|
||||
return self._last_robot_data
|
||||
else:
|
||||
# Return dict with default values (used only on first failure before any cache exists)
|
||||
return {
|
||||
"speed": 0,
|
||||
"battery": 0,
|
||||
"orientation": 0,
|
||||
"latitude": 0.0,
|
||||
"longitude": 0.0,
|
||||
"gps_signal": 0,
|
||||
"signal_level": 0,
|
||||
"vibration": 0.0,
|
||||
"lamp": 0,
|
||||
}
|
||||
|
||||
# Return dict with default values (used only on first failure before any cache exists)
|
||||
return {
|
||||
"speed": 0,
|
||||
"battery": 0,
|
||||
"orientation": 0,
|
||||
"latitude": 0.0,
|
||||
"longitude": 0.0,
|
||||
"gps_signal": 0,
|
||||
"signal_level": 0,
|
||||
"vibration": 0.0,
|
||||
"lamp": 0,
|
||||
"accels": [],
|
||||
"gyros": [],
|
||||
"mags": [],
|
||||
"rpms": [],
|
||||
}
|
||||
|
||||
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
|
||||
"""Send control command to SDK.
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -22,7 +22,7 @@ from functools import cached_property
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Any
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -19,7 +19,7 @@ import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
@@ -19,7 +19,7 @@ from pathlib import Path
|
||||
import draccus
|
||||
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS
|
||||
|
||||
from .config import RobotConfig
|
||||
|
||||
@@ -24,7 +24,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user