mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
16 Commits
| SHA1 |
|---|
| 132ea975f0 |
| 961e0d9bcd |
| 6496728025 |
| 3b37bd0ca6 |
| 8e692e365c |
| f617b2c2bf |
| c6a51b9b60 |
| ab49c71c22 |
| 459efef8a0 |
| 5568ce7af1 |
| be0320a420 |
| 5222f3a4a7 |
| f9d12db9cf |
| 71aacda05e |
| e3deff00ad |
| 4dfa8cea65 |

@@ -22,6 +22,10 @@ outputs
rl
media

# Local virtualenvs (the image provides its own)
.venv
venv


# Logging
logs

@@ -67,6 +67,8 @@
  title: VLA-JEPA
- local: eo1
  title: EO-1
- local: lingbot_va
  title: LingBot-VA
- local: groot
  title: NVIDIA GR00T N1.5
- local: xvla

@@ -0,0 +1,187 @@
# LingBot-VA

LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
integration wires LingBot-VA into the standard training, evaluation and processor
interfaces.

## Model Overview

LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
text conditioning.

| Component                | Class                   | Role                                                        |
| ------------------------ | ----------------------- | ----------------------------------------------------------- |
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer.                          |
| VAE (frozen)             | `AutoencoderKLWan`      | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo.   |
| Text encoder (frozen)    | `UMT5EncoderModel`      | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |

At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
fed back into the KV cache as the chunk is executed (closed-loop world modeling).

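The two denoising loops described above can be pictured as Euler integration of a learned velocity field. The sketch below is illustrative only: `euler_denoise`, `toy_velocity`, the linear noise path and the time convention (`t = 1` is pure noise, `t = 0` is clean data) are assumptions made for this example, not the LingBot-VA implementation. Only the step counts (about 20 for video, about 50 for actions) come from the text above.

```python
from typing import Callable, List

Vector = List[float]


def euler_denoise(
    x_noise: Vector,
    velocity: Callable[[Vector, float], Vector],
    num_steps: int,
) -> Vector:
    """Integrate dx/dt = velocity(x, t) from t=1 (noise) down to t=0 (data)."""
    x = list(x_noise)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = 1.0 - step * dt
        v = velocity(x, t)
        # We move from t to t - dt, so the update subtracts dt * v.
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x


# Toy stand-in for the transformer (hypothetical): on the linear path
# x_t = (1 - t) * x0 + t * noise, the true velocity is the constant (noise - x0).
target = [0.5, -1.0, 2.0]
noise = [1.5, 0.25, -0.75]


def toy_velocity(x: Vector, t: float) -> Vector:
    return [n - d for n, d in zip(noise, target)]


# Two independent loops with different step counts, one per stream.
video_latents = euler_denoise(noise, toy_velocity, num_steps=20)
actions = euler_denoise(noise, toy_velocity, num_steps=50)
```

With this toy velocity both loops recover `target` regardless of step count. A real model's velocity depends on `x`, `t` and the conditioning, which is why the number of steps matters in practice.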
### What the LeRobot Integration Covers

- Standard `policy.type=lingbot_va` configuration through LeRobot.
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
- Autoregressive dual-stream inference behind the standard `select_action` interface
  (single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`); see
  "Training / fine-tuning".

## Installation

1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the LingBot-VA extra:

```bash
pip install -e ".[lingbot_va]"
```

## Checkpoints

The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:

| Variant                | LeRobot checkpoint               |
| ---------------------- | -------------------------------- |
| LIBERO-Long post-train | `lerobot/lingbot_va_libero_long` |
| RoboTwin post-train    | `lerobot/lingbot_va_robotwin`    |
| Pretrained base        | `lerobot/lingbot_va_base`        |

Only the trainable ~5B transformer is stored in the LeRobot
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are pulled from
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
transformer + VAE fit on a single 24–32 GB GPU.

## Evaluation (LIBERO)

```bash
lerobot-eval \
  --policy.path=lerobot/lingbot_va_libero_long \
  --policy.device=cuda \
  --env.type=libero --env.task=libero_10 \
  --env.observation_height=128 --env.observation_width=128 \
  --eval.n_episodes=50 --eval.batch_size=1 \
  --output_dir=outputs/eval/lingbot_va_libero
```

LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
single-environment eval; use `--eval.batch_size=1`.

## Evaluation (RoboTwin)

RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack. You can use the benchmark Docker image
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
executed via CuRobo IK.

```bash
lerobot-eval \
  --policy.path=lerobot/lingbot_va_robotwin \
  --policy.device=cuda \
  --env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
  --eval.n_episodes=10 --eval.batch_size=1 \
  --output_dir=outputs/eval/lingbot_va_robotwin
```

### Saving predicted (imagined) videos

Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
The same flag works for the periodic eval during `lerobot-train`.

## Training / fine-tuning

`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
with a linear-warmup-then-constant schedule (matching upstream).

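As a rough picture of what "linear-warmup-then-constant" means, the sketch below computes the learning rate for a given step. The function name, the `(step + 1)` convention, and the example values (`base_lr=1e-4`, `warmup_steps=100`) are hypothetical and chosen only for illustration; the actual preset values live in the policy configuration.

```python
def warmup_then_constant(step: int, base_lr: float, warmup_steps: int) -> float:
    """Ramp the learning rate linearly up to base_lr, then hold it constant."""
    if warmup_steps <= 0:
        return base_lr
    # The scale grows linearly during warmup and is clamped at 1.0 afterwards.
    scale = min(1.0, (step + 1) / warmup_steps)
    return base_lr * scale


# Hypothetical values, for illustration only.
lr_first = warmup_then_constant(step=0, base_lr=1e-4, warmup_steps=100)
lr_end_of_warmup = warmup_then_constant(step=99, base_lr=1e-4, warmup_steps=100)
lr_later = warmup_then_constant(step=10_000, base_lr=1e-4, warmup_steps=100)
```

Unlike cosine or linear-decay schedules, the rate never drops after warmup, so the total number of training steps does not affect the schedule.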
Requirements:

- The block-causal masks use PyTorch **flex-attention**, so build the policy with
  `--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
- The full 5B DiT does not fit a single 24–32 GB GPU under AdamW; fine-tune with **LoRA**
  (`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
  trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.

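A back-of-the-envelope estimate shows why full fine-tuning exceeds a 24–32 GB GPU. The sketch assumes fp32 weights, gradients and both AdamW moment buffers (four same-sized copies) and ignores activations. The 50M adapter size and the bf16 frozen base are hypothetical figures for illustration, not measured properties of the LoRA setup.

```python
def adamw_state_gb(num_params: float, bytes_per_value: int = 4) -> float:
    """Memory for weights + gradients + two AdamW moments, all at the same precision."""
    copies = 4  # weights, gradients, first moment, second moment
    return num_params * bytes_per_value * copies / 1e9


# Full fine-tuning of the ~5B DiT (assumed fp32 everywhere).
full_finetune_gb = adamw_state_gb(5e9)

# LoRA: frozen base weights (assumed bf16, 2 bytes each) plus a hypothetical 50M trainable adapter.
frozen_base_gb = 5e9 * 2 / 1e9
lora_finetune_gb = frozen_base_gb + adamw_state_gb(50e6)
```

Under these assumptions full fine-tuning needs about 80 GB for parameter state alone, while the LoRA variant needs about 11 GB before activations.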
```bash
lerobot-train \
  --policy.path=lerobot/lingbot_va_libero_long --policy.attn_mode=flex \
  --policy.use_peft=true \
  --dataset.repo_id=<your LeRobot-format dataset> \
  --batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
```

The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.

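For a concrete number, the LIBERO defaults listed under "Inference Hyperparameters (LIBERO)" are `frame_chunk_size = 4` and `action_per_frame = 4`. The helper below is a hypothetical sanity check, not part of LeRobot; it assumes actions arrive as a list of per-step vectors.

```python
frame_chunk_size = 4  # latent frames per item (LIBERO default from this page)
action_per_frame = 4  # action steps per latent frame (LIBERO default from this page)

actions_per_item = frame_chunk_size * action_per_frame


def check_action_chunk(actions: list, expected_steps: int = actions_per_item) -> None:
    """Raise if a dataset item does not carry the expected number of action steps."""
    if len(actions) != expected_steps:
        raise ValueError(f"expected {expected_steps} action steps, got {len(actions)}")


# One item: 16 action steps, each a 30-dim vector in the multi-embodiment layout.
example_item = [[0.0] * 30 for _ in range(actions_per_item)]
check_action_chunk(example_item)
```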
## Data format (action channels & camera order)

LingBot-VA is an **end-effector (Cartesian) pose** policy: it predicts EEF poses + gripper, not
joint positions. Actions live in a fixed multi-embodiment **30-dim** layout; map your robot's
action dimensions into these channels and pad the rest with `0` (`used_action_channel_ids` selects
the channels a given checkpoint actually uses):

| channels | meaning                                               |
| -------- | ----------------------------------------------------- |
| 0–6      | Left-arm end-effector pose                            |
| 7–13     | Right-arm end-effector pose                           |
| 14–20    | Left-arm joints (unused by the released checkpoints)  |
| 21–27    | Right-arm joints (unused by the released checkpoints) |
| 28       | Left gripper                                          |
| 29       | Right gripper                                         |

- **LIBERO** uses channels `0–6`: a 6-DoF EEF delta (xyz + rotation) + gripper (single arm).
- **RoboTwin** uses channels `[0–6, 28, 7–13, 29]`: left EEF (xyz + quaternion) + left gripper +
  right EEF + right gripper (16 dims). The env converts these poses to joint trajectories via
  CuRobo IK; joints are never predicted.

Joint-space datasets (or a different EEF convention) must be remapped into this schema before
fine-tuning these checkpoints.

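The remapping for the RoboTwin case can be sketched as a scatter into the 30-dim layout and the matching gather back out. The channel order follows the list above; the function names and the plain-list representation are assumptions for this example, not LeRobot API.

```python
ACTION_DIM = 30

# RoboTwin order: left EEF (0-6), left gripper (28), right EEF (7-13), right gripper (29).
ROBOTWIN_CHANNEL_IDS = list(range(0, 7)) + [28] + list(range(7, 14)) + [29]


def to_lingbot_layout(action: list, channel_ids: list = ROBOTWIN_CHANNEL_IDS) -> list:
    """Scatter a compact action vector into the zero-padded 30-dim layout."""
    if len(action) != len(channel_ids):
        raise ValueError(f"expected {len(channel_ids)} dims, got {len(action)}")
    full = [0.0] * ACTION_DIM
    for value, channel in zip(action, channel_ids):
        full[channel] = float(value)
    return full


def from_lingbot_layout(full: list, channel_ids: list = ROBOTWIN_CHANNEL_IDS) -> list:
    """Gather the used channels back into the compact order."""
    return [full[channel] for channel in channel_ids]


compact = [float(i) for i in range(1, 17)]  # 16 dummy values: 1.0 .. 16.0
padded = to_lingbot_layout(compact)
restored = from_lingbot_layout(padded)
```

The joint channels `14–27` stay at `0`, matching the padding rule above. A single-arm dataset in the LIBERO convention would use `channel_ids=list(range(7))` instead.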
**Camera order is fixed and significant.** Per-camera latents are concatenated spatially in
`obs_cam_keys` order, so the physical camera→slot mapping must match training:

| benchmark | `obs_cam_keys` (in order)                                                                             | `camera_layout`                                                     |
| --------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
| LIBERO    | `observation.images.image` (agentview / 3rd-person), `observation.images.image2` (eye-in-hand wrist)  | `width_concat` (latents concatenated on width)                      |
| RoboTwin  | `observation.images.head_camera`, `observation.images.left_camera`, `observation.images.right_camera` | `robotwin_tshape` (full-res head below, two half-res wrists on top) |

The first camera is the exterior/head view and the rest are wrist views.

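To see why the order matters, consider a toy version of the `width_concat` layout. The helper below is a hypothetical illustration on nested lists (one 2D grid per camera), not the policy's actual tensor code; only the camera keys and their order come from the table above.

```python
LIBERO_CAM_KEYS = ["observation.images.image", "observation.images.image2"]


def width_concat(latents_by_key: dict, cam_keys: list) -> list:
    """Concatenate per-camera 2D grids row by row, in the order given by cam_keys."""
    grids = [latents_by_key[key] for key in cam_keys]  # KeyError if a camera is missing
    height = len(grids[0])
    if any(len(grid) != height for grid in grids):
        raise ValueError("all cameras must share the same latent height")
    # For each row index, join that row from every camera left to right.
    return [sum((grid[row] for grid in grids), []) for row in range(height)]


toy_latents = {
    "observation.images.image": [[1, 2], [3, 4]],   # agentview
    "observation.images.image2": [[5, 6], [7, 8]],  # wrist
}
correct = width_concat(toy_latents, LIBERO_CAM_KEYS)
swapped = width_concat(toy_latents, list(reversed(LIBERO_CAM_KEYS)))
```

`correct` and `swapped` contain the same values in different spatial positions, which is the kind of mismatch a checkpoint trained on one camera order would receive if the cameras were wired in another.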
## Inference Hyperparameters (LIBERO)

| Key                                    | Value                                                                             |
| -------------------------------------- | --------------------------------------------------------------------------------- |
| height × width                         | 128 × 128                                                                         |
| cameras                                | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
| action channels used                   | 0–6 (6-DoF EEF delta + gripper)                                                   |
| action_per_frame / frame_chunk_size    | 4 / 4                                                                             |
| attn_window                            | 30                                                                                |
| video / action denoising steps         | 20 / 50                                                                           |
| guidance_scale / action_guidance_scale | 5 / 1                                                                             |
| snr_shift / action_snr_shift           | 5.0 / 0.05                                                                        |

These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.

## Notes

- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
  `flashattn` and `flex` backends are optional; `flex` is only needed for training.
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
  roughly 18–24 GB of VRAM.

## License

LingBot-VA is released under Apache-2.0. See the
[upstream repository](https://github.com/Robbyant/lingbot-va).

@@ -1,547 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Single-image dataloading benchmark across the LeRobot loaders, MADE TO RUN ON A COMPUTE CLUSTER (SLURM).

This one file is both the orchestrator and the worker:

* Run it with no ``--scenario`` (from a login node) and it submits a SERIAL sbatch chain of all
  scenarios below (no two network-bound jobs overlap, so CDN numbers stay clean).
* Run it with ``--scenario <name>`` and it executes that single benchmark (this is what each sbatch
  job calls). The 2-node scenario is launched with ``srun`` and reads ``RANK``/``WORLD_SIZE`` so the
  streaming dataset splits shards per node.

Scenarios (all single-frame / non-SARM):
  1. ``mmap_local``            map-style LeRobotDataset over a LOCAL copy (``--local_root``, no network).
  2. ``mmap_local_maxworkers`` same, but workers scaled to saturate the node's cores (decode-bound).
  3. ``stream_hub``            StreamingLeRobotDataset from the Hub (allenai/MolmoAct2-BimanualYAM-Dataset).
  4. ``stream_bucket``         StreamingLeRobotDataset from a warmed storage bucket (1 node).
  5. ``stream_bucket_2node``   same warmed bucket, 2 nodes (split_dataset_by_node, per-rank results).

Reported per run: peak process-tree RSS (max memory), parallel throughput (samples/s, where a sample
is one timestep, plus decoded_frames/s = samples/s x num_cameras),
single-process throughput, shuffle randomness fraction (distinct episodes per batch / batch size),
fetch vs decode split (% of single-process per-sample time), first-batch latency, and p50/p95/p99
sample latency. Results are written as JSON + CSV under ``--out_dir``.

Submit the whole chain (from a login node, inside the repo). Point the scheduler env vars at your own
cluster's account/partition/qos, and ``--local_root`` at a local copy of the map-style dataset:
    ACCOUNT=<account> PARTITION=<partition> QOS=<qos> \\
      python examples/scaling/benchmark_dataloading.py --local_root /path/to/local/dataset
"""

import argparse
import csv
import json
import os
import random
import statistics
import subprocess
import sys
import threading
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.datasets.partition import group_episodes_by_files, partition_episodes

ROBOCASA_REPO = "pepijn223/robocasa_pretrain_human300_v4"
MOLMO_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
MOLMO_BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
# MolmoAct2 is published without a codebase-version git tag, so the version-safe loader would refuse
# it; "main" pins the branch directly and skips that check.
MOLMO_REVISION = "main"

# Per-scenario sbatch shape. mem is generous for the streaming legs (32k-episode, 3-camera, 2.35 TB
# dataset keeps many AV1 decoders open); the local map-style leg is light. Optional ``num_workers`` /
# ``cpus`` override the CLI defaults for that leg.
# ``mmap_local_maxworkers``: map-style decode is CPU-bound and each worker decodes its cameras on
# parallel threads, so the saturation point is ~num_cpus / num_cameras workers (~90 concurrent decode
# threads). The 96-core H100 nodes here schedule at most 92 cpus/task, so we take 92 cpus / 30 workers.
SCENARIOS = {
    "mmap_local": {"kind": "map", "nodes": 1, "mem": "64G", "time": "01:00:00"},
    "mmap_local_maxworkers": {
        "kind": "map",
        "nodes": 1,
        "mem": "128G",
        "time": "01:00:00",
        "num_workers": 30,
        "cpus": 92,
    },
    "stream_hub": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
    "stream_bucket": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
    "stream_bucket_2node": {"kind": "stream", "nodes": 2, "mem": "250G", "time": "03:00:00"},
}


def _tree_rss_bytes() -> int:
    """Sum RSS of this process and all descendants via /proc (DataLoader workers are separate procs)."""
    try:
        children: dict[int, list[int]] = {}
        for entry in os.listdir("/proc"):
            if not entry.isdigit():
                continue
            try:
                with open(f"/proc/{entry}/stat") as f:
                    ppid = int(f.read().split(") ", 1)[1].split()[1])
                children.setdefault(ppid, []).append(int(entry))
            except (OSError, ValueError, IndexError):
                pass
        total, stack = 0, [os.getpid()]
        while stack:
            cur = stack.pop()
            try:
                with open(f"/proc/{cur}/statm") as f:
                    total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
            except (OSError, ValueError, IndexError):
                pass
            stack.extend(children.get(cur, []))
        return total
    except OSError:
        return 0


class PeakRSSSampler:
    """Background thread tracking peak process-tree RSS for the duration of the ``with`` block."""

    def __init__(self, interval_s: float = 0.5):
        self.interval_s = interval_s
        self.peak_bytes = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        while not self._stop.is_set():
            self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
            self._stop.wait(self.interval_s)

    def __enter__(self) -> "PeakRSSSampler":
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        self._stop.set()
        self._thread.join(timeout=2)


def percentile(values: list[float], pct: float) -> float:
    if not values:
        return float("nan")
    ordered = sorted(values)
    k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
    return ordered[k]


class _TimedStreaming(StreamingLeRobotDataset):
    """StreamingLeRobotDataset that times the fetch stage (parquet/network row) separately from the
    decode stage (video decode + torch conversion in ``_finalize_sample``), so a single-process pass
    can attribute per-sample cost to fetch vs decode. Timing lives here in the benchmark, not in the
    library, to keep the dataset itself instrumentation-free."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fetch_s = 0.0
        self.decode_s = 0.0

    def __iter__(self):
        self._in_flight_epoch = self._epoch
        self._pipeline.set_epoch(self._in_flight_epoch)
        self._epoch += 1
        self.video_decoder_cache = self._make_video_decoder_cache()
        iterator = iter(self._pipeline)
        while True:
            t0 = time.perf_counter()
            try:
                row = next(iterator)
            except StopIteration:
                return
            t1 = time.perf_counter()
            sample = self._finalize_sample(row)
            t2 = time.perf_counter()
            self.fetch_s += t1 - t0
            self.decode_s += t2 - t1
            yield sample


def select_node_episodes(
    meta: LeRobotDatasetMetadata, num_partitions: int, index: int, cap: int
) -> list[int]:
    """This node's episode share, mirroring lerobot_train ``--data_partition=node``: group episodes by
    shared video files, LPT-balance the groups by frame count, take this node's bin (capped)."""
    episodes = list(range(meta.total_episodes))
    from_idx = meta.episodes["dataset_from_index"]
    to_idx = meta.episodes["dataset_to_index"]
    lengths = [int(to_idx[ep] - from_idx[ep]) for ep in episodes]
    if meta.video_keys:
        file_columns = {
            key: (meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])
            for key in meta.video_keys
        }
    else:
        file_columns = {"data": (meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])}
    episode_file_ids = [
        [(key, chunks[ep], files[ep]) for key, (chunks, files) in file_columns.items()] for ep in episodes
    ]
    groups = group_episodes_by_files(episode_file_ids)
    if len(groups) < num_partitions:
        groups = [[i] for i in range(len(episodes))]
    group_lengths = [sum(lengths[i] for i in g) for g in groups]
    bins = partition_episodes(group_lengths, num_partitions)
    chosen = sorted(episodes[i] for g in bins[index] for i in groups[g])
    return chosen[:cap] if cap and len(chosen) > cap else chosen


def build_dataset(scenario: str, args: argparse.Namespace):
    """Return (dataset, meta, is_map_style, info) for the scenario; single-frame (no delta windows)."""
    if scenario.startswith("mmap_local"):
        if not args.local_root:
            raise SystemExit("mmap_local needs --local_root pointing at a local LeRobotDataset copy.")
        meta = LeRobotDatasetMetadata(ROBOCASA_REPO, root=args.local_root)
        episodes = select_node_episodes(meta, args.num_partitions, args.partition_index, args.max_episodes)
        dataset = LeRobotDataset(ROBOCASA_REPO, root=args.local_root, episodes=episodes, tolerance_s=1e-3)
        return dataset, meta, True, {"loaded_episodes": len(episodes)}

    data_files_root = MOLMO_BUCKET if scenario.startswith("stream_bucket") else None
    meta = LeRobotDatasetMetadata(MOLMO_REPO, revision=MOLMO_REVISION)
    dataset = _TimedStreaming(
        MOLMO_REPO,
        revision=MOLMO_REVISION,
        data_files_root=data_files_root,
        episode_pool_size=args.episode_pool_size,
        max_buffer_input_shards=args.max_buffer_input_shards,
        video_decoder_cache_size=args.video_decoder_cache_size,
        tolerance_s=1e-3,
        # Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
        # dataset may be collapsed); reshard() still yields per-episode shards where it holds.
        validate_row_groups=False,
    )
    return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}


def _split(fetch_s: float, decode_s: float, getitem_s: float, n_probe: int) -> dict:
    stage = fetch_s + decode_s
    return {
        "single_proc_samples_per_s": round(n_probe / getitem_s, 2) if getitem_s else None,
        "fetch_pct": round(100 * fetch_s / stage, 1) if stage else None,
        "decode_pct": round(100 * decode_s / stage, 1) if stage else None,
    }


def measure_fetch_decode_stream(dataset: _TimedStreaming, n_probe: int, warmup: int) -> dict:
    """Single-process pass attributing per-sample time to fetch (parquet/network row) vs decode (video)."""
    it = iter(dataset)
    for _ in range(warmup):  # exclude the cold shuffle-buffer fill from the ratio
        next(it)
    dataset.fetch_s = dataset.decode_s = 0.0
    t0 = time.perf_counter()
    for _ in range(n_probe):
        next(it)
    return _split(dataset.fetch_s, dataset.decode_s, time.perf_counter() - t0, n_probe)


def measure_fetch_decode_map(dataset: LeRobotDataset, n_probe: int, warmup: int) -> dict:
    """Same split for the map-style loader: fetch = raw tabular row (``get_raw_item``), decode = the rest
    of ``__getitem__`` (video decode + transforms). Local reads make fetch tiny and decode dominant.

    Random frames are resampled past any that torchcodec fails to decode, so a single flaky frame can't
    abort the whole benchmark (the parallel DataLoader pass draws its own fresh random frames)."""
    rng = random.Random(0)
    n = len(dataset)
    fetch_s = getitem_s = 0.0
    warmed = measured = skipped = attempts = 0
    while measured < n_probe and attempts < (warmup + n_probe) * 10:
        attempts += 1
        i = rng.randrange(n)
        try:
            t0 = time.perf_counter()
            dataset.get_raw_item(i)
            t1 = time.perf_counter()
            dataset[i]
            t2 = time.perf_counter()
        except Exception:
            skipped += 1
            continue
        if warmed < warmup:
            warmed += 1
            continue
        fetch_s += t1 - t0
        getitem_s += t2 - t1
        measured += 1
    if skipped:
        print(f"map fetch/decode probe skipped {skipped} undecodable frame(s)", flush=True)
    return _split(fetch_s, max(0.0, getitem_s - fetch_s), getitem_s, measured)


def run_scenario(scenario: str, args: argparse.Namespace) -> None:
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    device = torch.device(args.device)

    dataset, meta, is_map_style, info = build_dataset(scenario, args)

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=is_map_style,  # map-style: global random shuffle; streaming: shuffled inside the dataset
        pin_memory=device.type == "cuda",
        drop_last=True,
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        persistent_workers=args.num_workers > 0,
    )

    sample_latencies_ms: list[float] = []
    episodes_per_batch: list[int] = []
    samples = 0
    first_batch_latency_s = None
    steady_start = None

    t_start = time.perf_counter()
    t_prev = t_start
    with PeakRSSSampler() as rss:
        for i, batch in enumerate(loader):
            for value in batch.values():
                if torch.is_tensor(value):
                    value.to(device, non_blocking=device.type == "cuda")
            now = time.perf_counter()
            if first_batch_latency_s is None:
                first_batch_latency_s = now - t_start
            if i == args.warmup_batches:
                steady_start = now
            elif i > args.warmup_batches:
                sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
                samples += args.batch_size
                ep = batch.get("episode_index")
                if torch.is_tensor(ep):
                    episodes_per_batch.append(int(torch.unique(ep).numel()))
            t_prev = now
            # Measure throughput over a fixed wall-clock window (after warmup) so every scenario is
            # compared over the same duration regardless of its speed; num_batches is only a safety cap.
            if steady_start is not None and (now - steady_start) >= args.duration_s:
                break
            if i + 1 >= args.num_batches:
                break
    peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None

    now = time.perf_counter()
    elapsed = now - t_start
    steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed

    if samples == 0:
        raise SystemExit(
            f"FAILED: 0 samples in {args.duration_s}s for scenario={scenario} "
            "(inspect worker logs; try --num_workers 0 to surface the exception)."
        )

    # Single-process fetch/decode split + single-proc throughput. Run AFTER the DataLoader pass: this
    # decodes video in the main process, which must stay decode-clean until the workers have forked
    # (decoding before fork corrupts the workers' torchcodec state).
    del loader
    if is_map_style:
        fetch_decode = measure_fetch_decode_map(dataset, args.probe_samples, args.probe_warmup)
    else:
        fetch_decode = measure_fetch_decode_stream(dataset, args.probe_samples, args.probe_warmup)

    image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
    num_cameras = len(meta.video_keys)
    results = {
        "scenario": scenario,
        "rank": rank,
        "world_size": world_size,
        "loader": "map_style" if is_map_style else "streaming",
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "episode_pool_size": None if is_map_style else args.episode_pool_size,
        "max_buffer_input_shards": None
        if is_map_style
        else (args.max_buffer_input_shards or args.episode_pool_size),
        **info,
        "num_cameras": num_cameras,
        "image_shape": image_shape,
        "fps": meta.fps,
        "peak_rss_gb": peak_rss_gb,
        "samples_measured": samples,
        "steady_window_s": round(steady_elapsed_s, 2),
        "first_batch_latency_s": round(first_batch_latency_s or float("nan"), 3),
        # Parallel throughput over the steady window (excludes warmup + the prefetch queue it filled).
        # A sample is one timestep (one dataset item); it decodes num_cameras video frames.
        "samples_per_s": round(samples / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
        "decoded_frames_per_s": round(samples / steady_elapsed_s * num_cameras, 2)
        if steady_elapsed_s
        else 0.0,
        **fetch_decode,
        # Distinct episodes per batch / batch size: ~1.0 ≈ map-style uniform, low ≈ correlated samples.
        "shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
        if episodes_per_batch
        else None,
        "p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
        if sample_latencies_ms
        else None,
        "p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
        "p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
        "total_time_s": round(elapsed, 2),
    }

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    tag = f"{scenario}_bs{args.batch_size}_w{args.num_workers}_r{rank}of{world_size}"
    (out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
    flat = {k: (json.dumps(v) if isinstance(v, (dict, list)) else v) for k, v in results.items()}
    with open(out_dir / f"{tag}.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(flat))
        writer.writeheader()
        writer.writerow(flat)
    print(json.dumps(results, indent=2), flush=True)
    print(f"Wrote {out_dir / tag}.json and .csv", flush=True)


def submit_chain(args: argparse.Namespace) -> None:
    """Submit every scenario as a serial sbatch chain (one network-bound job at a time).

    Bodies are passed to ``sbatch --wrap`` as a single argv (no outer shell), so ``$SLURM_PROCID`` /
    ``$SLURM_NTASKS`` stay literal and expand at job runtime, not at submit time.
    """
    this_file = Path(__file__).resolve()
    repo_dir = str(this_file.parents[2])  # <repo>/examples/scaling/<this file>
    logs = Path(repo_dir) / "logs"
    logs.mkdir(exist_ok=True)
    run = f"conda run --no-capture-output -n {args.conda_env} python"
    common = (
        f"--batch_size {args.batch_size} "
        f"--prefetch_factor {args.prefetch_factor} --episode_pool_size {args.episode_pool_size} "
        f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
        f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
    )
    if args.max_buffer_input_shards is not None:
        common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
    if args.local_root:
        common += f" --local_root {args.local_root}"
    env_prefix = "export TOKENIZERS_PARALLELISM=false"
    sched = []
    for opt, env in (("--account", "ACCOUNT"), ("--partition", "PARTITION"), ("--qos", "QOS")):
        if os.environ.get(env):
            sched.append(f"{opt}={os.environ[env]}")

    selected = args.scenarios.split(",") if args.scenarios else list(SCENARIOS)
    prev = ""
    for scenario in selected:
        cfg = SCENARIOS[scenario]
        nw = cfg.get("num_workers", args.num_workers)
        cpus = cfg.get("cpus", nw + 4)
        worker = f"{run} {this_file} --scenario {scenario} --num_workers {nw} {common}"
        if cfg["nodes"] > 1:
            # One task per node; each exports RANK/WORLD_SIZE so the stream splits shards per node.
            inner = f"export RANK=$SLURM_PROCID WORLD_SIZE=$SLURM_NTASKS && cd {repo_dir} && {env_prefix} && {worker}"
            body = f"srun --export=ALL bash -c '{inner}'"
            node_flags = [f"--nodes={cfg['nodes']}", "--ntasks-per-node=1", "--gpus-per-node=1"]
        else:
            body = f"cd {repo_dir} && {env_prefix} && {worker}"
            node_flags = ["--nodes=1", "--ntasks=1", "--gpus=1"]
        cmd = [
            "sbatch",
            "--parsable",
            f"--job-name=dlbench_{scenario}",
            *node_flags,
            f"--cpus-per-task={cpus}",
            f"--mem={cfg['mem']}",
            f"--time={cfg['time']}",
            f"--output={logs}/%x-%j.out",
            *sched,
        ]
        if prev:
            cmd.append(f"--dependency=afterany:{prev}")
        cmd += ["--wrap", body]
        jid = subprocess.check_output(cmd, text=True).strip().split(";")[0]
        print(f"submitted {jid} dlbench_{scenario}{f' (after {prev})' if prev else ''}", flush=True)
        prev = jid

    print(f"\nSubmitted {len(selected)} jobs as a serial chain. Results: {args.out_dir}/*.json", flush=True)


def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument(
|
||||
"--scenario",
|
||||
choices=list(SCENARIOS),
|
||||
default=None,
|
||||
help="Run ONE scenario (worker mode). Omit to submit the whole chain (orchestrator mode).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--scenarios",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Orchestrator only: comma-separated subset of scenarios to submit (default: all).",
|
||||
)
|
||||
p.add_argument("--local_root", type=str, default=None, help="Local LeRobotDataset copy for mmap_local.")
|
||||
p.add_argument(
|
||||
"--num_partitions", type=int, default=8, help="Node count for mmap_local episode partition."
|
||||
)
|
||||
p.add_argument("--partition_index", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_episodes", type=int, default=512, help="Cap mmap_local episodes to the local share."
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=64)
|
||||
p.add_argument("--num_workers", type=int, default=8)
|
||||
p.add_argument("--prefetch_factor", type=int, default=2)
|
||||
p.add_argument(
|
||||
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_buffer_input_shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Concurrently-live random episodes feeding the pool after reshard() "
|
||||
"(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--duration_s", type=float, default=60.0, help="Steady-state measurement window (seconds)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_batches", type=int, default=1_000_000, help="Safety cap; duration_s governs the window."
|
||||
)
|
||||
p.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state throughput.")
|
||||
p.add_argument(
|
||||
"--probe_samples", type=int, default=100, help="Single-process samples for fetch/decode split."
|
||||
)
|
||||
p.add_argument(
|
||||
"--probe_warmup", type=int, default=10, help="Samples skipped before the fetch/decode probe."
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
p.add_argument("--conda_env", type=str, default="lerobot", help="Conda env the chained jobs run in.")
|
||||
p.add_argument("--out_dir", type=str, default="benchmarks/streaming/results_dataloading")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.scenario is None:
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
"NOTE: no --scenario given, submitting the SLURM chain. This benchmark is meant to run on a "
|
||||
"compute cluster; run from a login node with ACCOUNT/PARTITION/QOS set.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
submit_chain(args)
|
||||
else:
|
||||
run_scenario(args.scenario, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
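A minimal sketch of how the serial chain in `submit_chain` hangs together, assuming only what the function above shows: each job is submitted with `sbatch --parsable`, and every job after the first carries `--dependency=afterany:<previous job id>`. The helper names `build_sbatch_cmd` and `parse_job_id` are hypothetical and not part of the repository.

```python
def build_sbatch_cmd(job_name: str, body: str, prev_job_id: str = "") -> list[str]:
    """Return an sbatch argv that runs `body`, optionally after `prev_job_id` ends."""
    cmd = ["sbatch", "--parsable", f"--job-name={job_name}"]
    if prev_job_id:
        # afterany: start once the previous job has terminated, whatever its exit status.
        cmd.append(f"--dependency=afterany:{prev_job_id}")
    # The body is one argv element, so no local shell expands $SLURM_* at submit time.
    cmd += ["--wrap", body]
    return cmd


def parse_job_id(parsable_output: str) -> str:
    """`sbatch --parsable` prints `jobid` or `jobid;cluster`; keep only the job id."""
    return parsable_output.strip().split(";")[0]


first = build_sbatch_cmd("dlbench_a", "echo $SLURM_PROCID")
second = build_sbatch_cmd("dlbench_b", "echo $SLURM_PROCID", prev_job_id="1234")
```

Because the command is a list rather than a shell string, `$SLURM_PROCID` reaches SLURM literally and is expanded only when the job runs.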
|
||||
+11
-15
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -146,7 +146,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
|
||||
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
@@ -216,8 +217,9 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -231,9 +233,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependency tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -284,6 +286,7 @@ all = [
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[lingbot_va]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -333,16 +336,6 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
|
||||
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
|
||||
# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships
|
||||
# them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||
# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the
|
||||
# shared httpx.Client on every ConnectError, which races with concurrent streaming requests
|
||||
# ("Cannot send a request, as the client has been closed"); we patch that out locally in
|
||||
# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch.
|
||||
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
@@ -385,6 +378,9 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
|
||||
# torch.nn.functional.
|
||||
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Build mmap-able byte-index sidecars for LeRobot streaming datasets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.datasets.byte_index_builder import (
|
||||
build_byte_index_tables,
|
||||
load_existing_file_ids,
|
||||
write_byte_index,
|
||||
)
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.")
|
||||
parser.add_argument("--repo-id", required=True)
|
||||
parser.add_argument("--revision", default=None)
|
||||
parser.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/")
|
||||
parser.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory")
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)")
|
||||
parser.add_argument("--no-keyframes", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
output = args.output
|
||||
existing = load_existing_file_ids(output)
|
||||
if existing:
|
||||
logger.info("resuming: %s files already indexed", len(existing))
|
||||
|
||||
files_tbl, episodes_tbl, keyframes_tbl = build_byte_index_tables(
|
||||
meta,
|
||||
args.data_root,
|
||||
include_keyframes=not args.no_keyframes,
|
||||
workers=args.workers,
|
||||
existing_files=existing,
|
||||
max_episodes=args.max_episodes,
|
||||
)
|
||||
write_byte_index(output, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True)
|
||||
logger.info("wrote byte index to %s", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -39,10 +39,6 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
|
||||
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
|
||||
# the pool holds tabular rows only. Ignored when streaming is False.
|
||||
streaming_episode_pool_size: int = 1024
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
"""Runtime in-memory byte index loaded from precomputed sidecar parquet."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME
|
||||
from .mp4_episode_slice import episode_custom_frame_mappings_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EpisodeSliceLookup:
|
||||
global_episode_id: int
|
||||
file_id: int
|
||||
mdat_offset: int
|
||||
mdat_length: int
|
||||
frame_count: int
|
||||
first_pts: float
|
||||
last_pts: float
|
||||
avg_fps: float
|
||||
|
||||
@property
|
||||
def fetch_bytes(self) -> int:
|
||||
return self.mdat_length
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FileLookup:
|
||||
file_id: int
|
||||
file_path: str
|
||||
file_size: int
|
||||
moov_offset: int
|
||||
moov_length: int
|
||||
header_length: int
|
||||
faststart: bool
|
||||
avg_fps: float
|
||||
codec: str
|
||||
|
||||
|
||||
class EpisodeByteIndex:
|
||||
"""Columnar byte-index resident in numpy arrays for O(1) episode lookup."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_dir: str | Path | None,
|
||||
*,
|
||||
video_keys: list[str],
|
||||
num_episodes: int,
|
||||
mmap: bool = True,
|
||||
files_table: pa.Table | None = None,
|
||||
episodes_table: pa.Table | None = None,
|
||||
mp4_by_rel: dict[str, Any] | None = None,
|
||||
):
|
||||
self.index_dir = Path(index_dir) if index_dir is not None else None
|
||||
self.video_keys = list(video_keys)
|
||||
self.num_episodes = num_episodes
|
||||
self.num_cameras = len(video_keys)
|
||||
self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)}
|
||||
self._mp4_by_rel = mp4_by_rel
|
||||
self._frame_mappings_by_gid: dict[int, bytes] = {}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
if files_table is not None and episodes_table is not None:
|
||||
files_tbl, episodes_tbl = files_table, episodes_table
|
||||
else:
|
||||
if self.index_dir is None:
|
||||
raise ValueError("index_dir or in-memory tables required")
|
||||
files_path = self.index_dir / FILES_NAME
|
||||
episodes_path = self.index_dir / EPISODES_NAME
|
||||
if not files_path.exists() or not episodes_path.exists():
|
||||
raise FileNotFoundError(f"byte index missing under {self.index_dir}")
|
||||
files_tbl = pq.read_table(files_path, memory_map=mmap)
|
||||
episodes_tbl = pq.read_table(episodes_path, memory_map=mmap)
|
||||
|
||||
self._load_tables(files_tbl, episodes_tbl, mmap=mmap)
|
||||
self.build_time_s = time.perf_counter() - t0
|
||||
self.load_time_s = self.build_time_s
|
||||
|
||||
def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None:
|
||||
def col(tbl, name: str):
|
||||
array = tbl.column(name).combine_chunks()
|
||||
if pa.types.is_boolean(array.type):
|
||||
return array.to_numpy(zero_copy_only=False)
|
||||
return array.to_numpy()
|
||||
|
||||
self.file_id = col(files_tbl, "file_id")
|
||||
self.file_path = files_tbl.column("file_path").to_pylist()
|
||||
self.file_size = col(files_tbl, "file_size")
|
||||
self.moov_offset = col(files_tbl, "moov_offset")
|
||||
self.moov_length = col(files_tbl, "moov_length")
|
||||
self.header_length = col(files_tbl, "header_length")
|
||||
self.faststart = col(files_tbl, "faststart")
|
||||
self.file_avg_fps = col(files_tbl, "avg_fps")
|
||||
self.codec = files_tbl.column("codec").to_pylist()
|
||||
|
||||
ep = episodes_tbl
|
||||
n = len(ep)
|
||||
gid = col(ep, "global_episode_id")
|
||||
order = np.argsort(gid)
|
||||
self._global_episode_id = gid[order]
|
||||
self._episode_index = col(ep, "episode_index")[order]
|
||||
self._camera_index = col(ep, "camera_index")[order]
|
||||
self._file_id = col(ep, "file_id")[order]
|
||||
self._mdat_offset = col(ep, "mdat_offset")[order]
|
||||
self._mdat_length = col(ep, "mdat_length")[order]
|
||||
self._frame_count = col(ep, "frame_count")[order]
|
||||
self._first_pts = col(ep, "first_pts")[order]
|
||||
self._last_pts = col(ep, "last_pts")[order]
|
||||
|
||||
expected = self.num_episodes * self.num_cameras
|
||||
if n != expected:
|
||||
raise ValueError(f"byte index episodes rows {n} != expected {expected}")
|
||||
|
||||
if self.index_dir is not None:
|
||||
keyframes_path = self.index_dir / KEYFRAMES_NAME
|
||||
if keyframes_path.exists():
|
||||
kf_tbl = pq.read_table(keyframes_path, memory_map=mmap)
|
||||
self._keyframes_rows = len(kf_tbl)
|
||||
else:
|
||||
self._keyframes_rows = 0
|
||||
else:
|
||||
self._keyframes_rows = 0
|
||||
|
||||
self.resident_bytes = int(
|
||||
self._global_episode_id.nbytes
|
||||
+ self._file_id.nbytes
|
||||
+ self._mdat_offset.nbytes
|
||||
+ self._mdat_length.nbytes
|
||||
+ self.file_size.nbytes
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex:
|
||||
return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes)
|
||||
|
||||
@classmethod
|
||||
def from_memory_build(
|
||||
cls,
|
||||
meta,
|
||||
data_root: str,
|
||||
*,
|
||||
workers: int = 8,
|
||||
max_episodes: int | None = None,
|
||||
include_frame_mappings_cache: bool = True,
|
||||
) -> EpisodeByteIndex:
|
||||
"""Build a complete byte index in RAM (no parquet write, no dataset push)."""
|
||||
from .byte_index_builder import build_byte_index_in_memory
|
||||
|
||||
return build_byte_index_in_memory(
|
||||
meta,
|
||||
data_root,
|
||||
workers=workers,
|
||||
max_episodes=max_episodes,
|
||||
include_frame_mappings_cache=include_frame_mappings_cache,
|
||||
)
|
||||
|
||||
def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup:
|
||||
cam_idx = self._cam_to_idx[camera_key]
|
||||
gid = episode_index * self.num_cameras + cam_idx
|
||||
row = int(gid)
|
||||
if row < 0 or row >= len(self._global_episode_id):
|
||||
raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
|
||||
file_id = int(self._file_id[row])
|
||||
return EpisodeSliceLookup(
|
||||
global_episode_id=gid,
|
||||
file_id=file_id,
|
||||
mdat_offset=int(self._mdat_offset[row]),
|
||||
mdat_length=int(self._mdat_length[row]),
|
||||
frame_count=int(self._frame_count[row]),
|
||||
first_pts=float(self._first_pts[row]),
|
||||
last_pts=float(self._last_pts[row]),
|
||||
avg_fps=float(self.file_avg_fps[file_id]),
|
||||
)
|
||||
|
||||
def file_lookup(self, file_id: int) -> FileLookup:
|
||||
return FileLookup(
|
||||
file_id=file_id,
|
||||
file_path=self.file_path[file_id],
|
||||
file_size=int(self.file_size[file_id]),
|
||||
moov_offset=int(self.moov_offset[file_id]),
|
||||
moov_length=int(self.moov_length[file_id]),
|
||||
header_length=int(self.header_length[file_id]),
|
||||
faststart=bool(self.faststart[file_id]),
|
||||
avg_fps=float(self.file_avg_fps[file_id]),
|
||||
codec=self.codec[file_id],
|
||||
)
|
||||
|
||||
def header_byte_range(self, file_id: int) -> tuple[int, int]:
|
||||
length = int(self.header_length[file_id])
|
||||
return 0, max(0, length - 1)
|
||||
|
||||
def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None:
|
||||
cam_idx = self._cam_to_idx[camera_key]
|
||||
gid = episode_index * self.num_cameras + cam_idx
|
||||
cached = self._frame_mappings_by_gid.get(gid)
|
||||
if cached is not None:
|
||||
return cached
|
||||
if self._mp4_by_rel is None:
|
||||
return None
|
||||
lookup = self.lookup(episode_index, camera_key)
|
||||
rel = self.file_path[lookup.file_id]
|
||||
mp4_index = self._mp4_by_rel.get(rel)
|
||||
if mp4_index is None:
|
||||
return None
|
||||
payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts)
|
||||
self._frame_mappings_by_gid[gid] = payload
|
||||
return payload
|
||||
|
||||
def stats_dict(self) -> dict[str, float | int]:
|
||||
return {
|
||||
"load_time_s": self.load_time_s,
|
||||
"build_time_s": self.build_time_s,
|
||||
"resident_bytes": self.resident_bytes,
|
||||
"frame_mappings_cached": len(self._frame_mappings_by_gid),
|
||||
"mp4_indices_cached": len(self._mp4_by_rel or {}),
|
||||
"num_files": len(self.file_path),
|
||||
"num_episode_rows": len(self._global_episode_id),
|
||||
}
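The lookups in `EpisodeByteIndex` rely on a dense row layout: the row for an (episode, camera) pair is `episode_index * num_cameras + camera_index`. The sketch below restates that arithmetic in isolation, assuming rows are sorted by this id with no gaps. The function names are hypothetical; only the formulas come from the class above.

```python
def global_episode_id(episode_index: int, camera_index: int, num_cameras: int) -> int:
    """Row id for one (episode, camera) pair in a dense, sorted episodes table."""
    return episode_index * num_cameras + camera_index


def split_global_episode_id(gid: int, num_cameras: int) -> tuple[int, int]:
    """Inverse mapping: recover (episode_index, camera_index) from the row id."""
    return divmod(gid, num_cameras)


def header_byte_range(header_length: int) -> tuple[int, int]:
    """Inclusive byte range covering the file header, clamped for empty headers."""
    return 0, max(0, header_length - 1)


gid = global_episode_id(episode_index=3, camera_index=1, num_cameras=2)
```

The mapping only works as a direct row index when every episode has a row for every camera, which is what the `n != expected` check in `_load_tables` guards by count.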
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from .mp4_episode_slice import (
|
||||
HEADER_PROBE_BYTES,
|
||||
MAX_HEADER_PROBE_BYTES,
|
||||
average_fps_from_index,
|
||||
episode_keyframes,
|
||||
parse_mp4_file_layout,
|
||||
parse_mp4_index,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BYTE_INDEX_DIR = "meta/byte_index"
|
||||
FILES_NAME = "files.parquet"
|
||||
EPISODES_NAME = "episodes.parquet"
|
||||
KEYFRAMES_NAME = "keyframes.parquet"
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexedFile:
|
||||
file_id: int
|
||||
file_path: str
|
||||
file_size: int
|
||||
moov_offset: int
|
||||
moov_length: int
|
||||
header_length: int
|
||||
faststart: bool
|
||||
avg_fps: float
|
||||
codec: str
|
||||
|
||||
|
||||
def fetch_header_bytes(path: str, file_size: int) -> bytes:
|
||||
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
|
||||
probe = HEADER_PROBE_BYTES
|
||||
while True:
|
||||
with fs.open(path, "rb", block_size=max(probe, 2**20), cache_type="none") as f:
|
||||
header = f.read(min(probe, file_size))
|
||||
try:
|
||||
parse_mp4_file_layout(header, file_size)
|
||||
return header
|
||||
except ValueError as exc:
|
||||
if probe >= min(MAX_HEADER_PROBE_BYTES, file_size) or "mdat box not found" not in str(exc):
|
||||
raise
|
||||
probe = min(probe * 2, MAX_HEADER_PROBE_BYTES, file_size)
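The retry loop in `fetch_header_bytes` doubles the probe size until the header parses or a cap is reached. A small sketch of the resulting read sizes, assuming the same doubling and clamping rules: `probe_sizes` is a hypothetical helper, and the numbers used here are illustrative because the real values of `HEADER_PROBE_BYTES` and `MAX_HEADER_PROBE_BYTES` are not shown in this diff.

```python
def probe_sizes(initial: int, cap: int, file_size: int) -> list[int]:
    """List the read sizes tried if every parse attempt fails until the limit."""
    sizes = []
    probe = initial
    while True:
        # Each attempt reads at most the whole file.
        sizes.append(min(probe, file_size))
        # Stop once the probe has reached the cap or the file size.
        if probe >= min(cap, file_size):
            break
        probe = min(probe * 2, cap, file_size)
    return sizes


schedule = probe_sizes(initial=4, cap=32, file_size=20)
```

The number of attempts grows only logarithmically with the cap, so a late `mdat` box costs a handful of extra range reads rather than a full download.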
|
||||
|
||||
|
||||
def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]:
|
||||
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
|
||||
file_size = fs.info(path)["size"]
|
||||
header = fetch_header_bytes(path, file_size)
|
||||
layout = parse_mp4_file_layout(header, file_size)
|
||||
if not layout.faststart:
|
||||
logger.warning("non-faststart MP4 (moov after mdat): %s", path)
|
||||
mp4_index = parse_mp4_index(header, file_size)
|
||||
indexed = IndexedFile(
|
||||
file_id=-1,
|
||||
file_path=rel_path or path,
|
||||
file_size=file_size,
|
||||
moov_offset=layout.moov_offset,
|
||||
moov_length=layout.moov_length,
|
||||
header_length=layout.header_end,
|
||||
faststart=layout.faststart,
|
||||
avg_fps=average_fps_from_index(mp4_index),
|
||||
codec=layout.codec,
|
||||
)
|
||||
return indexed, mp4_index
|
||||
|
||||
|
||||
def build_byte_index_tables(
|
||||
meta,
|
||||
data_root: str,
|
||||
*,
|
||||
file_paths: list[str] | None = None,
|
||||
include_keyframes: bool = True,
|
||||
workers: int = 8,
|
||||
existing_files: dict[str, int] | None = None,
|
||||
max_episodes: int | None = None,
|
||||
return_mp4_indices: bool = False,
|
||||
complete_files_table: bool = False,
|
||||
) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]:
|
||||
"""Build files/episodes/(optional keyframes) Arrow tables."""
|
||||
video_keys = list(meta.video_keys)
|
||||
n_cams = len(video_keys)
|
||||
cam_to_idx = {cam: i for i, cam in enumerate(video_keys)}
|
||||
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
|
||||
|
||||
rel_paths: set[str] = set()
|
||||
for ep_idx in range(num_episodes):
|
||||
for cam in video_keys:
|
||||
rel_paths.add(str(meta.get_video_file_path(ep_idx, cam)))
|
||||
path_by_rel = {rel: f"{data_root.rstrip('/')}/{rel}" for rel in sorted(rel_paths)}
|
||||
if file_paths is None:
|
||||
file_paths = list(path_by_rel.values())
|
||||
rel_by_path = {path_by_rel[rel]: rel for rel in path_by_rel}
|
||||
|
||||
existing_files = existing_files or {}
|
||||
file_meta_by_rel: dict[str, dict[str, Any]] = {}
|
||||
mp4_by_rel: dict[str, Any] = {}
|
||||
next_file_id = max(existing_files.values(), default=-1) + 1
|
||||
|
||||
to_index = [rel for rel in sorted(rel_paths) if rel not in existing_files]
|
||||
if to_index:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel for rel in to_index
|
||||
}
|
||||
for fut in as_completed(futures):
|
||||
rel = futures[fut]
|
||||
indexed, mp4_index = fut.result()
|
||||
indexed.file_id = next_file_id
|
||||
mp4_by_rel[rel] = mp4_index
|
||||
file_meta_by_rel[rel] = {
|
||||
"file_id": indexed.file_id,
|
||||
"file_path": rel,
|
||||
"file_size": indexed.file_size,
|
||||
"moov_offset": indexed.moov_offset,
|
||||
"moov_length": indexed.moov_length,
|
||||
"header_length": indexed.header_length,
|
||||
"faststart": indexed.faststart,
|
||||
"avg_fps": indexed.avg_fps,
|
||||
"codec": indexed.codec,
|
||||
}
|
||||
existing_files[rel] = indexed.file_id
|
||||
next_file_id += 1
|
||||
|
||||
missing_rels = {
|
||||
str(meta.get_video_file_path(ep, cam))
|
||||
for ep in range(num_episodes)
|
||||
for cam in video_keys
|
||||
} - set(mp4_by_rel.keys())
|
||||
if missing_rels:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel
|
||||
for rel in missing_rels
|
||||
if rel not in mp4_by_rel
|
||||
}
|
||||
for fut in as_completed(futures):
|
||||
rel = futures[fut]
|
||||
_, mp4_index = fut.result()
|
||||
mp4_by_rel[rel] = mp4_index
|
||||
|
||||
episode_rows: list[dict[str, Any]] = []
|
||||
keyframe_rows: list[dict[str, Any]] = []
|
||||
for ep_idx in range(num_episodes):
|
||||
for cam in video_keys:
|
||||
rel = str(meta.get_video_file_path(ep_idx, cam))
|
||||
path = f"{data_root.rstrip('/')}/{rel}"
|
||||
if rel not in existing_files:
|
||||
raise KeyError(f"file not indexed: {rel}")
|
||||
mp4_index = mp4_by_rel[rel]
|
||||
ep = meta.episodes[ep_idx]
|
||||
from_ts = float(ep[f"videos/{cam}/from_timestamp"])
|
||||
to_ts = float(ep[f"videos/{cam}/to_timestamp"])
|
||||
span = mp4_index.episode_byte_span(from_ts, to_ts)
|
||||
global_episode_id = ep_idx * n_cams + cam_to_idx[cam]
|
||||
mdat_length = span.slice_hi - span.slice_lo + 1
|
||||
episode_rows.append(
|
||||
{
|
||||
"global_episode_id": global_episode_id,
|
||||
"episode_index": ep_idx,
|
||||
"camera_key": cam,
|
||||
"camera_index": cam_to_idx[cam],
|
||||
"file_id": existing_files[rel],
|
||||
"mdat_offset": span.slice_lo,
|
||||
"mdat_length": mdat_length,
|
||||
"frame_count": max(1, round((to_ts - from_ts) * meta.fps)),
|
||||
"first_pts": from_ts,
|
||||
"last_pts": to_ts,
|
||||
}
|
||||
)
|
||||
if include_keyframes:
|
||||
timescale = mp4_index.timescale
|
||||
for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts):
|
||||
keyframe_rows.append(
|
||||
{
|
||||
"global_episode_id": global_episode_id,
|
||||
"pts": int(round(pts_s * timescale)),
|
||||
"byte_offset": byte_off,
|
||||
}
|
||||
)
|
||||
|
||||
referenced_rels = {
|
||||
str(meta.get_video_file_path(ep, cam)) for ep in range(num_episodes) for cam in video_keys
|
||||
}
|
||||
if complete_files_table:
|
||||
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(referenced_rels)])
|
||||
elif to_index:
|
||||
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(to_index)])
|
||||
else:
|
||||
files_table = None
|
||||
episodes_table = pa.Table.from_pylist(episode_rows)
|
||||
keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None
|
||||
if return_mp4_indices:
|
||||
return files_table, episodes_table, keyframes_table, mp4_by_rel
|
||||
return files_table, episodes_table, keyframes_table
|
||||
|
||||
|
||||
def build_byte_index_in_memory(
|
||||
meta,
|
||||
data_root: str,
|
||||
*,
|
||||
workers: int = 8,
|
||||
max_episodes: int | None = None,
|
||||
include_frame_mappings_cache: bool = False,
|
||||
):
|
||||
"""Build a complete byte index resident in RAM (no parquet write, no dataset push)."""
|
||||
from .byte_index import EpisodeByteIndex
|
||||
|
||||
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
|
||||
files_tbl, episodes_tbl, _, mp4_by_rel = build_byte_index_tables(
|
||||
meta,
|
||||
data_root,
|
||||
include_keyframes=False,
|
||||
workers=workers,
|
||||
max_episodes=max_episodes,
|
||||
return_mp4_indices=True,
|
||||
complete_files_table=True,
|
||||
)
|
||||
index = EpisodeByteIndex(
|
||||
None,
|
||||
video_keys=list(meta.video_keys),
|
||||
num_episodes=num_episodes,
|
||||
files_table=files_tbl,
|
||||
episodes_table=episodes_tbl,
|
||||
mp4_by_rel=mp4_by_rel,
|
||||
)
|
||||
if include_frame_mappings_cache:
|
||||
for ep_idx in range(num_episodes):
|
||||
for cam in meta.video_keys:
|
||||
index.custom_frame_mappings(ep_idx, cam)
|
||||
return index
|
||||
|
||||
|
||||
def write_byte_index(
|
||||
output_dir: Path,
|
||||
files_table: pa.Table | None,
|
||||
episodes_table: pa.Table,
|
||||
keyframes_table: pa.Table | None = None,
|
||||
*,
|
||||
merge_existing: bool = True,
|
||||
) -> None:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
files_path = output_dir / FILES_NAME
|
||||
episodes_path = output_dir / EPISODES_NAME
|
||||
keyframes_path = output_dir / KEYFRAMES_NAME
|
||||
|
||||
if merge_existing and files_path.exists() and files_table is not None:
|
||||
prev = pq.read_table(files_path)
|
||||
files_table = pa.concat_tables([prev, files_table])
|
||||
|
||||
if files_table is not None:
|
||||
pq.write_table(files_table, files_path)
|
||||
|
||||
pq.write_table(episodes_table, episodes_path)
|
||||
if keyframes_table is not None:
|
||||
if merge_existing and keyframes_path.exists():
|
||||
keyframes_table = pa.concat_tables([pq.read_table(keyframes_path), keyframes_table])
|
||||
pq.write_table(keyframes_table, keyframes_path)
|
||||
|
||||
|
||||
def load_existing_file_ids(index_dir: Path) -> dict[str, int]:
|
||||
files_path = index_dir / FILES_NAME
|
||||
if not files_path.exists():
|
||||
return {}
|
||||
table = pq.read_table(files_path, columns=["file_id", "file_path"])
|
||||
return {row["file_path"]: int(row["file_id"]) for row in table.to_pylist()}
|
||||
@@ -945,17 +945,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
# Emit several row groups with a page index instead of one giant row group. A single row group forces
|
||||
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
|
||||
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
|
||||
# groups bounds per-shard memory while keeping groups large enough to scan
|
||||
# efficiently; the page index lets readers skip to the pages they need.
|
||||
target_row_group_bytes = 32 * 1024 * 1024
|
||||
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
|
||||
writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
|
||||
)
|
||||
writer.write_table(table, row_group_size=row_group_size)
|
||||
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
writer.write_table(table)
|
||||
writer.close()
|
||||
|
||||
|
||||
|
||||
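The row-group sizing in the hunk above is plain integer arithmetic, so it can be checked without pyarrow. A minimal sketch, assuming a hypothetical helper name `rows_per_group` (not part of the diff) that takes the table's row count and uncompressed byte size:

```python
# Hypothetical helper mirroring the row_group_size expression in the hunk above.
TARGET_ROW_GROUP_BYTES = 32 * 1024 * 1024  # ~32 MiB uncompressed per group


def rows_per_group(num_rows: int, nbytes: int, target_bytes: int = TARGET_ROW_GROUP_BYTES) -> int:
    """Return how many rows to put in each row group.

    The result scales num_rows by target_bytes / nbytes, is capped at num_rows
    (small tables stay in one group) and floored at 1 (very wide rows).
    """
    return max(1, min(num_rows, num_rows * target_bytes // max(nbytes, 1)))


if __name__ == "__main__":
    # A 64 MiB table of 1000 rows is split into two groups of 500 rows.
    print(rows_per_group(1000, 64 * 1024 * 1024))
```

The `max(nbytes, 1)` guard avoids a division by zero for an empty table, and the outer `max(1, ...)` keeps the writer from being asked for zero-row groups when a single row exceeds the target.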
@@ -1,263 +0,0 @@
|
||||
"""Node-local LRU byte cache using precomputed byte-index manifest sidecars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
|
||||
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
|
||||
from .mp4_episode_slice import SparseMp4Reader
|
||||
from .torchcodec_utils import open_video_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
bytes_fetched: int = 0
|
||||
full_file_fallbacks: int = 0
|
||||
prefetch_submitted: int = 0
|
||||
prefetch_waits: int = 0
|
||||
mdat_slices: int = 0
|
||||
prefix_fetches: int = 0
|
||||
fetch_to_buffer_s: float = 0.0
|
||||
buffer_to_decoder_s: float = 0.0
|
||||
buffer_hit_decoder_s: float = 0.0
|
||||
decode_frame_s: float = 0.0
|
||||
decode_frames: int = 0
|
||||
|
||||
def merge(self, other: CacheStats) -> None:
|
||||
for name in self.__dataclass_fields__:
|
||||
setattr(self, name, getattr(self, name) + getattr(other, name))
|
||||
|
||||
def stats_dict(self) -> dict[str, int | float]:
|
||||
avg_miss = self.bytes_fetched / max(1, self.misses)
|
||||
return {
|
||||
"byte_cache_hits": self.hits,
|
||||
"byte_cache_misses": self.misses,
|
||||
"byte_cache_bytes_fetched": self.bytes_fetched,
|
||||
"byte_cache_bytes_per_miss": avg_miss,
|
||||
"byte_cache_full_file_fallbacks": self.full_file_fallbacks,
|
||||
"byte_cache_prefetch_submitted": self.prefetch_submitted,
|
||||
"byte_cache_prefetch_waits": self.prefetch_waits,
|
||||
"byte_cache_mdat_slices": self.mdat_slices,
|
||||
"byte_cache_prefix_fetches": self.prefix_fetches,
|
||||
"fetch_to_buffer_ms_per_miss": 1000 * self.fetch_to_buffer_s / max(1, self.misses),
|
||||
"buffer_to_decoder_ms_per_miss": 1000 * self.buffer_to_decoder_s / max(1, self.misses),
|
||||
"buffer_hit_decoder_ms_per_hit": 1000 * self.buffer_hit_decoder_s / max(1, self.hits),
|
||||
"decode_ms_per_frame": 1000 * self.decode_frame_s / max(1, self.decode_frames),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _EpisodeEntry:
|
||||
decoders: dict[str, Any] = field(default_factory=dict)
|
||||
ready: threading.Event = field(default_factory=threading.Event)
|
||||
error: Exception | None = None
|
||||
|
||||
|
||||
class RangeFetcher:
|
||||
"""Sequential byte-range GETs via fsspec."""
|
||||
|
||||
def __init__(self, path: str):
|
||||
self.path = path
|
||||
self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
|
||||
|
||||
def fetch(self, lo: int, hi: int) -> bytes:
|
||||
if hi < lo:
|
||||
return b""
|
||||
with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f:
|
||||
f.seek(lo)
|
||||
return f.read(hi - lo + 1)
|
||||
|
||||
|
||||
class EpisodeByteCache:
|
||||
"""Manifest-driven episode MP4 fetch + in-memory sparse decode."""
|
||||
|
||||
MAX_BYTES_PER_MISS = 25 * 1024 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
byte_index: EpisodeByteIndex,
|
||||
max_bytes: int,
|
||||
*,
|
||||
data_root: str,
|
||||
max_prefetch_workers: int = 4,
|
||||
):
|
||||
if max_bytes <= 0:
|
||||
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
|
||||
self.byte_index = byte_index
|
||||
self.max_bytes = max_bytes
|
||||
self.data_root = data_root.rstrip("/")
|
||||
self._bytes_used = 0
|
||||
self._lock = threading.Lock()
|
||||
self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict()
|
||||
self._header_cache: dict[int, bytes] = {}
|
||||
self._fetcher_cache: dict[int, RangeFetcher] = {}
|
||||
self._episodes: dict[int, _EpisodeEntry] = {}
|
||||
self._stats = CacheStats()
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
|
||||
self._futures: dict[int, Future] = {}
|
||||
|
||||
@property
|
||||
def stats(self) -> CacheStats:
|
||||
with self._lock:
|
||||
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
|
||||
|
||||
def submit_prefetch(self, ep_idx: int) -> None:
|
||||
with self._lock:
|
||||
if ep_idx in self._episodes or ep_idx in self._futures:
|
||||
return
|
||||
self._stats.prefetch_submitted += 1
|
||||
fut = self._executor.submit(self._prefetch_episode, ep_idx)
|
||||
self._futures[ep_idx] = fut
|
||||
|
||||
def ensure_ready(self, ep_idx: int) -> None:
|
||||
with self._lock:
|
||||
fut = self._futures.pop(ep_idx, None)
|
||||
if fut is not None:
|
||||
with self._lock:
|
||||
self._stats.prefetch_waits += 1
|
||||
fut.result()
|
||||
entry = self._episodes.get(ep_idx)
|
||||
if entry is None:
|
||||
raise KeyError(f"episode {ep_idx} not prefetched")
|
||||
if entry.error is not None:
|
||||
raise entry.error
|
||||
entry.ready.wait()
|
||||
|
||||
def get_decoder(self, ep_idx: int, video_key: str) -> Any:
|
||||
entry = self._episodes[ep_idx]
|
||||
if entry.error is not None:
|
||||
raise entry.error
|
||||
entry.ready.wait()
|
||||
return entry.decoders[video_key]
|
||||
|
||||
def close(self) -> None:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
def _prefetch_episode(self, ep_idx: int) -> None:
|
||||
entry = _EpisodeEntry()
|
||||
self._episodes[ep_idx] = entry
|
||||
try:
|
||||
for cam in self.byte_index.video_keys:
|
||||
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
|
||||
except Exception as exc:
|
||||
entry.error = exc
|
||||
finally:
|
||||
entry.ready.set()
|
||||
|
||||
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
|
||||
key = (ep_idx, cam)
|
||||
with self._lock:
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.hits += 1
|
||||
payload, _ = cached
|
||||
t0 = time.perf_counter()
|
||||
dec = self._decoder_from_payload(payload, ep_idx, cam)
|
||||
with self._lock:
|
||||
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
|
||||
return dec
|
||||
|
||||
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
|
||||
|
||||
with self._lock:
|
||||
self._stats.misses += 1
|
||||
if payload_bytes > self.MAX_BYTES_PER_MISS:
|
||||
logger.warning(
|
||||
"byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s",
|
||||
payload_bytes / 1e6,
|
||||
ep_idx,
|
||||
cam,
|
||||
)
|
||||
self._evict_until(payload_bytes)
|
||||
self._cache[key] = (payload, payload_bytes)
|
||||
self._bytes_used += payload_bytes
|
||||
return dec
|
||||
|
||||
def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]:
|
||||
lookup = self.byte_index.lookup(ep_idx, cam)
|
||||
file_info = self.byte_index.file_lookup(lookup.file_id)
|
||||
fetcher = self._get_fetcher(lookup.file_id, file_info.file_path)
|
||||
t_fetch = time.perf_counter()
|
||||
header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length)
|
||||
lo = lookup.mdat_offset
|
||||
hi = lo + lookup.mdat_length - 1
|
||||
mdat = fetcher.fetch(lo, hi)
|
||||
fetch_s = time.perf_counter() - t_fetch
|
||||
nbytes = len(header) + len(mdat)
|
||||
with self._lock:
|
||||
self._stats.bytes_fetched += nbytes
|
||||
self._stats.mdat_slices += 1
|
||||
self._stats.fetch_to_buffer_s += fetch_s
|
||||
|
||||
def lazy_fetch(pos: int, end: int) -> bytes:
|
||||
data = fetcher.fetch(pos, end)
|
||||
with self._lock:
|
||||
self._stats.bytes_fetched += len(data)
|
||||
return data
|
||||
|
||||
reader = SparseMp4Reader(
|
||||
file_size=file_info.file_size,
|
||||
header=header,
|
||||
mdat_lo=lo,
|
||||
mdat_bytes=mdat,
|
||||
lazy_fetch=lazy_fetch,
|
||||
)
|
||||
t_init = time.perf_counter()
|
||||
dec = self._decoder_from_payload(reader, ep_idx, cam)
|
||||
self._validate_decoder(dec, lookup)
|
||||
init_s = time.perf_counter() - t_init
|
||||
with self._lock:
|
||||
self._stats.buffer_to_decoder_s += init_s
|
||||
self._rewind_payload(reader)
|
||||
return reader, nbytes, dec
|
||||
|
||||
def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher:
|
||||
if file_id not in self._fetcher_cache:
|
||||
path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}"
|
||||
self._fetcher_cache[file_id] = RangeFetcher(path)
|
||||
return self._fetcher_cache[file_id]
|
||||
|
||||
def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes:
|
||||
if file_id in self._header_cache:
|
||||
return self._header_cache[file_id]
|
||||
hi = max(0, header_length - 1)
|
||||
header = fetcher.fetch(0, hi)
|
||||
with self._lock:
|
||||
self._header_cache[file_id] = header
|
||||
self._stats.bytes_fetched += len(header)
|
||||
return header
|
||||
|
||||
def _decoder_from_payload(
|
||||
self, payload: SparseMp4Reader, ep_idx: int, cam: str
|
||||
) -> Any:
|
||||
payload.seek(0)
|
||||
mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
|
||||
return open_video_decoder(payload, frame_mappings=mappings)
|
||||
|
||||
def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None:
|
||||
begin = float(dec.metadata.begin_stream_seconds)
|
||||
end = float(dec.metadata.end_stream_seconds)
|
||||
duration = max(0.01, end - begin)
|
||||
for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
|
||||
dec.get_frames_played_at([ts]).data
|
||||
|
||||
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
|
||||
payload.seek(0)
|
||||
|
||||
def _evict_until(self, need: int) -> None:
|
||||
while self._bytes_used + need > self.max_bytes and self._cache:
|
||||
_, (_, size) = self._cache.popitem(last=False)
|
||||
self._bytes_used -= size
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -1,555 +0,0 @@
|
||||
"""MP4 moov parsing and tight per-episode mdat byte-range fetching.
|
||||
|
||||
LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head).
|
||||
For streaming we fetch only the file header plus the episode's contiguous mdat span
|
||||
instead of the ``0..episode_end`` prefix.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
KEYFRAME_PAD_S = 0.1
|
||||
HEADER_PROBE_BYTES = 4 * 1024 * 1024
|
||||
MAX_HEADER_PROBE_BYTES = 16 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mp4FileLayout:
|
||||
file_size: int
|
||||
moov_offset: int
|
||||
moov_length: int
|
||||
header_end: int
|
||||
mdat_offset: int
|
||||
mdat_size: int
|
||||
faststart: bool
|
||||
codec: str
|
||||
|
||||
|
||||
def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout:
|
||||
"""Return top-level MP4 layout (moov/mdat positions, faststart flag)."""
|
||||
boxes = list(_iter_boxes(header_bytes))
|
||||
moov_offset = mdat_offset = -1
|
||||
moov_length = mdat_size = 0
|
||||
for off, size, typ, _ in boxes:
|
||||
if typ == b"moov" and moov_offset < 0:
|
||||
moov_offset, moov_length = off, size
|
||||
if typ == b"mdat" and mdat_offset < 0:
|
||||
mdat_offset, mdat_size = off, size
|
||||
if moov_offset < 0:
|
||||
raise ValueError("moov box not found in header probe")
|
||||
if mdat_offset < 0:
|
||||
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
|
||||
faststart = moov_offset < mdat_offset
|
||||
header_end = mdat_offset
|
||||
codec = _parse_video_codec(header_bytes)
|
||||
return Mp4FileLayout(
|
||||
file_size=file_size,
|
||||
moov_offset=moov_offset,
|
||||
moov_length=moov_length,
|
||||
header_end=header_end,
|
||||
mdat_offset=mdat_offset,
|
||||
mdat_size=mdat_size,
|
||||
faststart=faststart,
|
||||
codec=codec,
|
||||
)
|
||||
|
||||
|
||||
def _parse_video_codec(header_bytes: bytes) -> str:
|
||||
moov = _find_box_payload(header_bytes, b"moov")
|
||||
if moov is None:
|
||||
return "unknown"
|
||||
trak = _find_video_trak(moov)
|
||||
if trak is None:
|
||||
return "unknown"
|
||||
stsd = _find_box_payload(_find_box_payload(trak, b"stbl") or b"", b"stsd")
|
||||
if stsd is None or len(stsd) < 16:
|
||||
return "unknown"
|
||||
# stsd payload: version(1)+flags(3)+entry_count(4)+entry_size(4)+codec(4), so the fourcc is bytes 12..16
|
||||
if len(stsd) >= 16:
|
||||
return stsd[12:16].decode("latin1", errors="replace").strip("\x00")
|
||||
return "unknown"
|
||||
|
||||
|
||||
def average_fps_from_index(index: Mp4VideoIndex) -> float:
|
||||
index.ensure_tables()
|
||||
if index.num_samples < 2:
|
||||
return 30.0
|
||||
duration = index.sample_pts(index.num_samples - 1)
|
||||
if duration <= 0:
|
||||
return 30.0
|
||||
return (index.num_samples - 1) / duration
|
||||
|
||||
|
||||
def episode_custom_frame_mappings_json(
|
||||
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
|
||||
) -> bytes:
|
||||
"""Build TorchCodec ``custom_frame_mappings`` JSON for one episode span."""
|
||||
import json
|
||||
|
||||
index.ensure_tables()
|
||||
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
|
||||
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
|
||||
hi_idx = min(hi_idx, index.num_samples - 1)
|
||||
lo_idx = _keyframe_back(index.sync_samples, lo_idx)
|
||||
|
||||
sync = set(index.sync_samples)
|
||||
timescale = index.timescale
|
||||
# stts deltas for duration per sample (expand stts entries to per-sample delta)
|
||||
sample_deltas: list[int] = []
|
||||
for count, delta in index.stts:
|
||||
sample_deltas.extend([delta] * count)
|
||||
while len(sample_deltas) < index.num_samples:
|
||||
sample_deltas.append(sample_deltas[-1] if sample_deltas else timescale // 30)
|
||||
|
||||
frames = []
|
||||
for idx in range(lo_idx, hi_idx + 1):
|
||||
frames.append(
|
||||
{
|
||||
"pts": int(round(index._pts[idx] * timescale)),
|
||||
"duration": int(sample_deltas[idx]),
|
||||
"key_frame": int((idx + 1) in sync) if sync else int(idx == lo_idx),
|
||||
}
|
||||
)
|
||||
return json.dumps({"frames": frames}).encode()
|
||||
|
||||
|
||||
def episode_keyframes(
|
||||
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
|
||||
) -> list[tuple[float, int]]:
|
||||
"""Return (pts_seconds, byte_offset) for sync samples in the episode span."""
|
||||
index.ensure_tables()
|
||||
span = index.episode_byte_span(from_ts, to_ts, keyframe_pad_s)
|
||||
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
|
||||
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
|
||||
if not index.sync_samples:
|
||||
return [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
|
||||
out: list[tuple[float, int]] = []
|
||||
for sync_one_based in index.sync_samples:
|
||||
idx = sync_one_based - 1
|
||||
if lo_idx <= idx <= hi_idx:
|
||||
out.append((index.sample_pts(idx), index.sample_offset(idx)))
|
||||
return out or [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeByteSpan:
|
||||
"""Absolute file byte ranges to fetch for one episode."""
|
||||
|
||||
file_size: int
|
||||
header_end: int
|
||||
slice_lo: int
|
||||
slice_hi: int
|
||||
|
||||
@property
|
||||
def header_bytes(self) -> tuple[int, int]:
|
||||
return 0, self.header_end - 1
|
||||
|
||||
@property
|
||||
def mdat_bytes(self) -> tuple[int, int]:
|
||||
return self.slice_lo, self.slice_hi
|
||||
|
||||
@property
|
||||
def total_fetch_bytes(self) -> int:
|
||||
header = self.header_end
|
||||
mdat = self.slice_hi - self.slice_lo + 1
|
||||
return header + mdat
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mp4VideoIndex:
|
||||
file_size: int
|
||||
header_end: int
|
||||
mdat_offset: int
|
||||
mdat_size: int
|
||||
timescale: int
|
||||
stts: list[tuple[int, int]]
|
||||
stsz: list[int]
|
||||
stsc: list[tuple[int, int, int]]
|
||||
stco: list[int]
|
||||
sync_samples: list[int]
|
||||
_pts: list[float] = field(default_factory=list, repr=False)
|
||||
_offsets: list[int] = field(default_factory=list, repr=False)
|
||||
|
||||
def ensure_tables(self) -> None:
|
||||
if self._pts:
|
||||
return
|
||||
self._pts = _pts_from_stts(self.stts, self.timescale)
|
||||
self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.stsz)
|
||||
|
||||
def sample_pts(self, index: int) -> float:
|
||||
self.ensure_tables()
|
||||
return self._pts[index]
|
||||
|
||||
def sample_offset(self, index: int) -> int:
|
||||
self.ensure_tables()
|
||||
index = max(0, min(index, len(self._offsets) - 1))
|
||||
return self._offsets[index]
|
||||
|
||||
def sample_end(self, index: int) -> int:
|
||||
return self.sample_offset(index) + self.stsz[index]
|
||||
|
||||
def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan:
|
||||
self.ensure_tables()
|
||||
n = self.num_samples
|
||||
if n == 0:
|
||||
raise ValueError("MP4 has no video samples")
|
||||
|
||||
pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts))
|
||||
lo_ts = max(0.0, from_ts - pad)
|
||||
hi_ts = to_ts + pad
|
||||
|
||||
lo_idx = _first_sample_at_or_after(self._pts, lo_ts)
|
||||
hi_idx = _last_sample_at_or_before(self._pts, hi_ts)
|
||||
hi_idx = min(hi_idx, n - 1)
|
||||
lo_idx = min(lo_idx, n - 1)
|
||||
|
||||
lo_idx = _keyframe_back(self.sync_samples, lo_idx)
|
||||
|
||||
slice_lo = self.sample_offset(lo_idx)
|
||||
slice_hi = self.sample_end(min(hi_idx, len(self._offsets) - 1))
|
||||
return EpisodeByteSpan(
|
||||
file_size=self.file_size,
|
||||
header_end=self.header_end,
|
||||
slice_lo=slice_lo,
|
||||
slice_hi=min(slice_hi, self.file_size - 1),
|
||||
)
|
||||
|
||||
|
||||
class SparseMp4Reader(io.BufferedIOBase):
|
||||
"""Range-backed MP4 reader: header + one mdat span at absolute offsets."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_size: int,
|
||||
header: bytes,
|
||||
mdat_lo: int,
|
||||
mdat_bytes: bytes,
|
||||
lazy_fetch: Callable[[int, int], bytes] | None = None,
|
||||
):
|
||||
self._size = file_size
|
||||
self._header = header
|
||||
self._mdat_lo = mdat_lo
|
||||
self._mdat_hi = mdat_lo + len(mdat_bytes)
|
||||
self._mdat = mdat_bytes
|
||||
self._lazy_fetch = lazy_fetch
|
||||
self._pos = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return True
|
||||
|
||||
def tell(self) -> int:
|
||||
return self._pos
|
||||
|
||||
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
|
||||
if whence == io.SEEK_SET:
|
||||
self._pos = offset
|
||||
elif whence == io.SEEK_CUR:
|
||||
self._pos += offset
|
||||
elif whence == io.SEEK_END:
|
||||
self._pos = self._size + offset
|
||||
else:
|
||||
raise ValueError(f"invalid whence: {whence}")
|
||||
self._pos = max(0, min(self._pos, self._size))
|
||||
return self._pos
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if size < 0:
|
||||
size = self._size - self._pos
|
||||
if size <= 0:
|
||||
return b""
|
||||
|
||||
out = bytearray()
|
||||
remaining = size
|
||||
pos = self._pos
|
||||
while remaining > 0 and pos < self._size:
|
||||
chunk = self._read_at(pos, remaining)
|
||||
if not chunk:
|
||||
break
|
||||
out.extend(chunk)
|
||||
pos += len(chunk)
|
||||
remaining -= len(chunk)
|
||||
self._pos = pos
|
||||
return bytes(out)
|
||||
|
||||
def _read_at(self, pos: int, n: int) -> bytes:
|
||||
header_len = len(self._header)
|
||||
if pos < header_len:
|
||||
end = min(pos + n, header_len)
|
||||
return self._header[pos:end]
|
||||
|
||||
if self._mdat_lo <= pos < self._mdat_hi:
|
||||
end = min(pos + n, self._mdat_hi)
|
||||
off = pos - self._mdat_lo
|
||||
return self._mdat[off : off + (end - pos)]
|
||||
|
||||
if self._lazy_fetch is not None:
|
||||
with self._lock:
|
||||
end = min(pos + n, self._size)
|
||||
return self._lazy_fetch(pos, end - 1)
|
||||
|
||||
return b"\x00" * min(n, self._size - pos)
|
||||
|
||||
|
||||
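The read path of `SparseMp4Reader` maps an absolute file offset onto one of two in-memory regions. The sketch below illustrates that mapping only; it is not the class from the diff. `TwoRegionReader` is a hypothetical name, it takes an explicit position instead of tracking one, and it always zero-fills uncovered ranges (the original does that only when no `lazy_fetch` callback is supplied).

```python
class TwoRegionReader:
    """Serve reads from a header region and one data span at absolute offsets."""

    def __init__(self, file_size: int, header: bytes, span_lo: int, span: bytes):
        self.size = file_size
        self.header = header
        self.span_lo = span_lo
        self.span_hi = span_lo + len(span)  # exclusive end of the data span
        self.span = span

    def read(self, pos: int, n: int) -> bytes:
        out = bytearray()
        end = min(pos + n, self.size)
        while pos < end:
            if pos < len(self.header):
                chunk = self.header[pos:min(end, len(self.header))]
            elif self.span_lo <= pos < self.span_hi:
                off = pos - self.span_lo
                chunk = self.span[off:off + (min(end, self.span_hi) - pos)]
            else:
                # Uncovered gap: zero-fill up to the next region or the read end.
                nxt = self.span_lo if pos < self.span_lo else end
                chunk = b"\x00" * (min(end, nxt) - pos)
            out += chunk
            pos += len(chunk)
        return bytes(out)
```

Keeping absolute offsets means the chunk-offset tables in the MP4 header remain valid without rewriting them, which is presumably why the original reader is addressed the same way.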
def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex:
|
||||
"""Parse moov sample tables from the file header (faststart layout)."""
|
||||
layout = parse_mp4_file_layout(header_bytes, file_size)
|
||||
mdat_offset, mdat_size = layout.mdat_offset, layout.mdat_size
|
||||
moov = _find_box_payload(header_bytes, b"moov")
|
||||
if moov is None:
|
||||
raise ValueError("moov box not found in MP4 header probe")
|
||||
|
||||
trak = _find_video_trak(moov)
|
||||
if trak is None:
|
||||
raise ValueError("video trak not found in moov")
|
||||
|
||||
mdhd = _find_box_payload(trak, b"mdhd")
|
||||
if mdhd is None:
|
||||
raise ValueError("mdhd not found")
|
||||
timescale = _parse_mdhd_timescale(mdhd)
|
||||
|
||||
stbl = _find_box_payload(trak, b"stbl")
|
||||
if stbl is None:
|
||||
raise ValueError("stbl not found")
|
||||
|
||||
stts = _parse_stts(_find_box_payload(stbl, b"stts"))
|
||||
stsz = _parse_stsz(_find_box_payload(stbl, b"stsz"))
|
||||
stsc = _parse_stsc(_find_box_payload(stbl, b"stsc"))
|
||||
stco_payload = _find_box_payload(stbl, b"stco")
|
||||
co64_payload = _find_box_payload(stbl, b"co64")
|
||||
if stco_payload is not None:
|
||||
stco = _parse_stco(stco_payload)
|
||||
elif co64_payload is not None:
|
||||
stco = _parse_co64(co64_payload)
|
||||
else:
|
||||
raise ValueError("stco/co64 not found")
|
||||
|
||||
stss_payload = _find_box_payload(stbl, b"stss")
|
||||
sync_samples = _parse_stss(stss_payload) if stss_payload else []
|
||||
|
||||
return Mp4VideoIndex(
|
||||
file_size=file_size,
|
||||
header_end=layout.header_end,
|
||||
mdat_offset=mdat_offset,
|
||||
mdat_size=mdat_size,
|
||||
timescale=timescale,
|
||||
stts=stts,
|
||||
stsz=stsz,
|
||||
stsc=stsc,
|
||||
stco=stco,
|
||||
sync_samples=sync_samples,
|
||||
)
|
||||
|
||||
|
||||
def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None:
|
||||
if offset + 8 > len(data):
|
||||
return None
|
||||
size, typ = struct.unpack_from(">I4s", data, offset)
|
||||
header = 8
|
||||
if size == 1:
|
||||
if offset + 16 > len(data):
|
||||
return None
|
||||
size = struct.unpack_from(">Q", data, offset + 8)[0]
|
||||
header = 16
|
||||
elif size == 0:
|
||||
size = len(data) - offset
|
||||
return size, typ, header
|
||||
|
||||
|
||||
def _iter_boxes(data: bytes, start: int = 0, end: int | None = None):
|
||||
end = end if end is not None else len(data)
|
||||
off = start
|
||||
while off + 8 <= end:
|
||||
hdr = _box_header(data, off)
|
||||
if hdr is None or hdr[0] < hdr[2]:
|
||||
break
|
||||
size, typ, header = hdr
|
||||
yield off, size, typ, data[off + header : off + size]
|
||||
off += size
|
||||
|
||||
|
||||
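The box walk in `_box_header` and `_iter_boxes` follows the ISO base media file layout: a 32-bit big-endian size, a four-byte type, and an optional 64-bit size when the 32-bit field is 1. A self-contained sketch on a synthetic buffer, where `make_box` and `iter_boxes` are illustrative names rather than functions from the diff:

```python
import struct


def make_box(typ: bytes, payload: bytes) -> bytes:
    """Build a box with a 32-bit size field (8-byte header)."""
    return struct.pack(">I4s", 8 + len(payload), typ) + payload


def iter_boxes(data: bytes):
    """Yield (offset, size, type, payload) for each top-level box in data."""
    off = 0
    while off + 8 <= len(data):
        size, typ = struct.unpack_from(">I4s", data, off)
        header = 8
        if size == 1:  # 64-bit "largesize" follows the type field
            if off + 16 > len(data):
                break
            size = struct.unpack_from(">Q", data, off + 8)[0]
            header = 16
        elif size == 0:  # box extends to the end of the buffer
            size = len(data) - off
        if size < header:  # malformed box; stop instead of looping forever
            break
        yield off, size, typ, data[off + header:off + size]
        off += size


sample = make_box(b"ftyp", b"isom") + make_box(b"moov", b"\x00" * 4) + make_box(b"mdat", b"frames")
layout = [(typ, off, size) for off, size, typ, _ in iter_boxes(sample)]
offsets = {typ: off for typ, off, _ in layout}
faststart = offsets[b"moov"] < offsets[b"mdat"]
```

Here `moov` precedes `mdat`, so `faststart` is true, which is the layout the module docstring says LeRobot v3 files use.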
def _find_box_payload(data: bytes, target: bytes) -> bytes | None:
|
||||
for _, _, typ, payload in _iter_boxes(data):
|
||||
if typ == target:
|
||||
return payload
|
||||
if typ in (b"moov", b"trak", b"mdia", b"minf", b"stbl"):
|
||||
found = _find_box_payload(payload, target)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> bytes | None:
|
||||
for _, _, typ, payload in _iter_boxes(moov):
|
||||
if typ != b"trak":
|
||||
continue
|
||||
hdlr = _find_box_payload(payload, b"hdlr")
|
||||
if hdlr is not None and len(hdlr) >= 12 and hdlr[8:12] == b"vide":
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]:
|
||||
for off, size, typ, _ in _iter_boxes(header_bytes):
|
||||
if typ == b"mdat":
|
||||
return off, size
|
||||
# mdat may start beyond the probed bytes; there is no fallback scan based on file_size, so the probe must contain it
|
||||
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
|
||||
|
||||
|
||||
def _parse_mdhd_timescale(mdhd: bytes) -> int:
|
||||
version = mdhd[0]
|
||||
if version == 0:
|
||||
return struct.unpack_from(">I", mdhd, 12)[0]
|
||||
return struct.unpack_from(">I", mdhd, 20)[0]
|
||||
|
||||
|
||||
def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]:
|
||||
if stts is None:
|
||||
raise ValueError("stts missing")
|
||||
count = struct.unpack_from(">I", stts, 4)[0]
|
||||
out = []
|
||||
off = 8
|
||||
for _ in range(count):
|
||||
sample_count, delta = struct.unpack_from(">II", stts, off)
|
||||
out.append((sample_count, delta))
|
||||
off += 8
|
||||
return out
|
||||
|
||||
|
||||
def _parse_stsz(stsz: bytes | None) -> list[int]:
|
||||
if stsz is None:
|
||||
raise ValueError("stsz missing")
|
||||
sample_size, sample_count = struct.unpack_from(">II", stsz, 4)
|
||||
if sample_size != 0:
|
||||
return [sample_size] * sample_count
|
||||
off = 12
|
||||
return list(struct.unpack_from(f">{sample_count}I", stsz, off))
|
||||
|
||||
|
||||
def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]:
|
||||
if stsc is None:
|
||||
raise ValueError("stsc missing")
|
||||
count = struct.unpack_from(">I", stsc, 4)[0]
|
||||
out = []
|
||||
off = 8
|
||||
for _ in range(count):
|
||||
first_chunk, samples_per_chunk, sample_desc = struct.unpack_from(">III", stsc, off)
|
||||
out.append((first_chunk, samples_per_chunk, sample_desc))
|
||||
off += 12
|
||||
return out
|
||||
|
||||
|
||||
def _parse_stco(stco: bytes) -> list[int]:
|
||||
count = struct.unpack_from(">I", stco, 4)[0]
|
||||
return list(struct.unpack_from(f">{count}I", stco, 8))
|
||||
|
||||
|
||||
def _parse_co64(co64: bytes) -> list[int]:
|
||||
count = struct.unpack_from(">I", co64, 4)[0]
|
||||
return [struct.unpack_from(">Q", co64, 8 + i * 8)[0] for i in range(count)]
|
||||
|
||||
|
||||
def _parse_stss(stss: bytes) -> list[int]:
|
||||
count = struct.unpack_from(">I", stss, 4)[0]
|
||||
return list(struct.unpack_from(f">{count}I", stss, 8))
|
||||
|
||||
|
||||
def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
|
||||
pts: list[float] = []
|
||||
t = 0
|
||||
for count, delta in stts:
|
||||
for _ in range(count):
|
||||
pts.append(t / timescale)
|
||||
t += delta
|
||||
return pts
|
||||
|
||||
|
||||
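The `stts` box stores run-length encoded `(sample_count, delta)` pairs in timescale ticks, and `_pts_from_stts` expands them into one presentation time per sample. A small worked sketch with made-up values (a timescale of 1000 and two runs), using the hypothetical name `pts_from_stts`:

```python
def pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
    """Expand (sample_count, delta) runs into per-sample start times in seconds."""
    pts: list[float] = []
    ticks = 0
    for count, delta in stts:
        for _ in range(count):
            pts.append(ticks / timescale)
            ticks += delta  # each sample starts where the previous one ended
    return pts


# Three samples lasting 40 ticks each, then two lasting 20 ticks each.
example = pts_from_stts([(3, 40), (2, 20)], timescale=1000)
```

Accumulating integer ticks and dividing once per sample avoids the rounding drift that would build up from repeatedly adding floating-point durations.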
def _sample_byte_offsets(
|
||||
stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
|
||||
) -> list[int]:
|
||||
if not stsc:
|
||||
stsc = [(1, len(stsz), 1)]
|
||||
|
||||
offsets: list[int] = []
|
||||
chunk_idx = 0
|
||||
sample_idx = 0
|
||||
sc_idx = 0
|
||||
num_chunks = len(stco)
|
||||
|
||||
while chunk_idx < num_chunks and sample_idx < len(stsz):
|
||||
first_chunk, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
|
||||
if sc_idx + 1 < len(stsc):
|
||||
next_first = stsc[sc_idx + 1][0]
|
||||
chunks_in_entry = next_first - first_chunk
|
||||
else:
|
||||
chunks_in_entry = num_chunks - chunk_idx
|
||||
|
||||
for _ in range(chunks_in_entry):
|
||||
if chunk_idx >= num_chunks:
|
||||
break
|
||||
offset = stco[chunk_idx]
|
||||
_, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
|
||||
for _ in range(samples_per_chunk):
|
||||
if sample_idx >= len(stsz):
|
||||
break
|
||||
offsets.append(offset)
|
||||
offset += stsz[sample_idx]
|
||||
sample_idx += 1
|
||||
chunk_idx += 1
|
||||
sc_idx += 1
|
||||
|
||||
if len(offsets) < len(stsz):
|
||||
# Pad with last known offset progression for malformed stsc edge cases.
|
||||
last = offsets[-1] if offsets else 0
|
||||
while len(offsets) < len(stsz):
|
||||
idx = len(offsets)
|
||||
offsets.append(last)
|
||||
last += stsz[idx]
|
||||
|
||||
return offsets
|
||||
|
||||
|
||||
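Per-sample byte offsets come from combining three tables: `stco` (chunk start offsets), `stsc` (how many samples each chunk holds, as runs keyed by 1-based first chunk) and `stsz` (sample sizes). The sketch below is a simplified variant of `_sample_byte_offsets`, not a copy: `sample_offsets` is a hypothetical name, it assumes a non-empty `stsc` whose first entry starts at chunk 1, and it omits the padding for malformed tables.

```python
def sample_offsets(
    stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
) -> list[int]:
    """Return the absolute byte offset of every sample."""
    offsets: list[int] = []
    sample = 0
    for chunk_no, chunk_off in enumerate(stco, start=1):
        # The governing stsc entry is the last one whose first_chunk <= chunk_no.
        per_chunk = next(spc for first, spc, _ in reversed(stsc) if first <= chunk_no)
        off = chunk_off
        for _ in range(per_chunk):
            if sample >= len(stsz):
                break
            offsets.append(off)
            off += stsz[sample]  # samples are packed back to back inside a chunk
            sample += 1
    return offsets


# Chunks 1 and 2 hold two samples each, chunk 3 holds one.
example = sample_offsets(
    stsc=[(1, 2, 1), (3, 1, 1)],
    stco=[100, 500, 900],
    stsz=[10, 20, 30, 40, 50],
)
```

With these offsets, an episode's byte span is the offset of its first needed sample through the end of its last one, which is what `episode_byte_span` computes.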
def _first_sample_at_or_after(pts: list[float], ts: float) -> int:
|
||||
lo, hi = 0, len(pts)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if pts[mid] < ts:
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid
|
||||
return min(lo, len(pts) - 1)
|
||||
|
||||
|
||||
def _last_sample_at_or_before(pts: list[float], ts: float) -> int:
|
||||
lo, hi = 0, len(pts)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if pts[mid] <= ts:
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid
|
||||
return max(0, lo - 1)
|
||||
|
||||
|
||||
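The two hand-written binary searches above correspond to the standard library's `bisect_left` and `bisect_right` with the same clamping. A sketch of that equivalence, assuming `pts` is sorted ascending and non-empty; the function names are illustrative:

```python
from bisect import bisect_left, bisect_right


def first_at_or_after(pts: list[float], ts: float) -> int:
    """Index of the first sample with pts >= ts, clamped to the last sample."""
    return min(bisect_left(pts, ts), len(pts) - 1)


def last_at_or_before(pts: list[float], ts: float) -> int:
    """Index of the last sample with pts <= ts, clamped to the first sample."""
    return max(0, bisect_right(pts, ts) - 1)
```

The clamping matters at the edges: a timestamp past the end of the file still resolves to the final sample, and one before the first sample resolves to index 0, so padded episode windows never index out of range.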
def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
|
||||
if not sync_samples:
|
||||
return max(0, sample_idx - 2)
|
||||
# stss stores 1-based sample numbers
|
||||
one_based = sample_idx + 1
|
||||
prev = [s for s in sync_samples if s <= one_based]
|
||||
if prev:
|
||||
return prev[-1] - 1
|
||||
return 0
|
||||
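`_keyframe_back` snaps a sample index back to the nearest preceding sync sample so decoding can start on a keyframe. The original scans the whole `stss` list; the sketch below swaps in a binary search, which assumes the list is sorted ascending (the usual case for `stss`). The name `keyframe_back` and the example values are illustrative.

```python
from bisect import bisect_right


def keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
    """Return the 0-based index of the last sync sample at or before sample_idx."""
    if not sync_samples:
        # No stss table: mirror the original fallback of stepping back two samples.
        return max(0, sample_idx - 2)
    # stss stores 1-based sample numbers, so compare against sample_idx + 1.
    pos = bisect_right(sync_samples, sample_idx + 1)
    return sync_samples[pos - 1] - 1 if pos else 0


# Keyframes every 30 samples: requesting sample 45 snaps back to sample 30.
start = keyframe_back([1, 31, 61], 45)
```

Snapping backwards can only widen the fetched byte range, which trades a few extra bytes for a decodable starting point.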
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,49 +0,0 @@
|
||||
"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torchcodec import FrameBatch, _core as core
|
||||
from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata
|
||||
|
||||
|
||||
def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
data = json.loads(payload)
|
||||
frames = data["frames"]
|
||||
pts = torch.tensor([int(f["pts"]) for f in frames], dtype=torch.int64)
|
||||
key = torch.tensor([bool(f["key_frame"]) for f in frames], dtype=torch.bool)
|
||||
dur = torch.tensor([int(f["duration"]) for f in frames], dtype=torch.int64)
|
||||
return pts, key, dur
|
||||
|
||||
|
||||
class VideoDecoderLike:
|
||||
"""Minimal VideoDecoder surface used by episode byte cache."""
|
||||
|
||||
def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None):
|
||||
self._decoder = decoder
|
||||
(
|
||||
self.metadata,
|
||||
self.stream_index,
|
||||
self._begin_stream_seconds,
|
||||
self._end_stream_seconds,
|
||||
self._num_frames,
|
||||
) = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index)
|
||||
|
||||
def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
|
||||
return FrameBatch(*core.get_frames_by_pts(self._decoder, timestamps=seconds))
|
||||
|
||||
|
||||
def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike:
|
||||
"""Open a decoder on sparse or full MP4 IO, skipping metadata scan when mappings exist."""
|
||||
if frame_mappings is None:
|
||||
decoder = core.create_from_file_like(source, "approximate")
|
||||
core.add_video_stream(decoder)
|
||||
return VideoDecoderLike(decoder)
|
||||
|
||||
mappings = frame_mappings_tensors(frame_mappings)
|
||||
decoder = core.create_from_file_like(source, "custom_frame_mappings")
|
||||
core.add_video_stream(decoder, custom_frame_mappings=mappings)
|
||||
return VideoDecoderLike(decoder)
|
||||
@@ -273,11 +273,7 @@ class VideoDecoderCache:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
|
||||
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
|
||||
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
|
||||
# mostly-sequential reads torchcodec issues.
|
||||
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
@@ -326,7 +322,6 @@ def decode_video_frames_torchcodec(
|
||||
log_loaded_timestamps: bool = False,
|
||||
decoder_cache: VideoDecoderCache | None = None,
|
||||
return_uint8: bool = False,
|
||||
episode_decoder: Any | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
@@ -348,10 +343,8 @@ def decode_video_frames_torchcodec(
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
if episode_decoder is not None:
|
||||
decoder = episode_decoder
|
||||
else:
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
|
||||
loaded_ts = []
|
||||
loaded_frames = []
|
||||
|
||||
@@ -757,7 +757,7 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
|
||||
task: str = "beat_block_hammer" # single task or comma-separated list
|
||||
fps: int = 25
|
||||
episode_length: int = 300
|
||||
episode_length: int = 1200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
# Available cameras from RoboTwin's aloha-agilex embodiment: head_camera
|
||||
@@ -768,6 +768,9 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
# must equal what SAPIEN actually renders.
|
||||
observation_height: int = 240
|
||||
observation_width: int = 320
|
||||
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
|
||||
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
|
||||
action_mode: str = "joint"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
@@ -784,6 +787,8 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.action_mode == "ee":
|
||||
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
|
||||
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
|
||||
for cam in cam_list:
|
||||
self.features[f"pixels/{cam}"] = PolicyFeature(
|
||||
@@ -826,6 +831,7 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
observation_height=self.observation_height,
|
||||
observation_width=self.observation_width,
|
||||
episode_length=self.episode_length,
|
||||
action_mode=self.action_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
@@ -41,10 +42,117 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
|
||||
"right_camera",
|
||||
)
|
||||
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
|
||||
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
|
||||
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
|
||||
EEF_ACTION_DIM = 16
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
DEFAULT_EPISODE_LENGTH = 300
|
||||
DEFAULT_EPISODE_LENGTH = 1200
|
||||
OFFICIAL_INSTRUCTION_ENV = "LEROBOT_ROBOTWIN_OFFICIAL_INSTRUCTION"
|
||||
OFFICIAL_INSTRUCTION_TYPE_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_TYPE"
|
||||
OFFICIAL_INSTRUCTION_MAX_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_MAX"
|
||||
|
||||
|
||||
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
    """Compose a single-arm predicted delta pose onto the initial pose.

    ``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
    is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
    prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
    """
    from scipy.spatial.transform import Rotation

    new_r = Rotation.from_quat(new_pose[3:7])
    init_r = Rotation.from_quat(init_pose[3:7])
    out_rot = (init_r * new_r).as_quat()
    out_trans = new_pose[:3] + init_pose[:3]
    return np.concatenate([out_trans, out_rot, new_pose[7:8]])


def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
    """Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
    left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
    right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
    out = np.concatenate([left, right])
    # Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
    out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
    out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
    return out
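The pose composition above can be checked by hand with a dependency-free sketch. It assumes scalar-last quaternions `(x, y, z, w)`, the convention SciPy's `Rotation.from_quat` uses by default, and replaces `init_r * new_r` with an explicit Hamilton product. The helper names `quat_multiply` and `compose_pose` and the sample poses are hypothetical, chosen only for illustration.

```python
import math


def quat_multiply(q1, q2):
    """Hamilton product q1 * q2 for scalar-last quaternions (x, y, z, w)."""
    x1, y1, z1, w1 = q1
    x2, y2, z2, w2 = q2
    return (
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    )


def compose_pose(delta, init):
    """Apply an 8-d delta pose [x, y, z, qx, qy, qz, qw, gripper] to an initial pose."""
    # Translation is added component-wise.
    trans = [d + i for d, i in zip(delta[:3], init[:3])]
    # Rotation is composed as init * delta, then re-normalized.
    quat = quat_multiply(init[3:7], delta[3:7])
    norm = math.sqrt(sum(c * c for c in quat)) + 1e-8
    quat = [c / norm for c in quat]
    # The gripper value is taken from the prediction, not the initial pose.
    return trans + quat + [delta[7]]


half = math.sqrt(0.5)
init = [0.1, 0.2, 0.3, 0.0, 0.0, half, half, 0.0]      # 90 degrees about z
delta = [0.01, 0.0, -0.02, 0.0, 0.0, half, half, 1.0]  # a further 90 degrees about z
pose = compose_pose(delta, init)
```

Two quarter turns about z compose to a half turn, so the resulting quaternion is approximately `(0, 0, 1, 0)`, while the translation is simply the sum of the two offsets.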
|
||||
|
||||
|
||||
def _env_flag(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}
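A short sketch of how this flag parsing behaves. The function body is copied from `_env_flag`; the variable name `DEMO_FLAG` is hypothetical and used only for this demonstration.

```python
import os


def env_flag(name, default=False):
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# Unset variable: the caller's default wins.
os.environ.pop("DEMO_FLAG", None)
unset_result = env_flag("DEMO_FLAG", default=True)

# Surrounding whitespace and letter case are ignored.
os.environ["DEMO_FLAG"] = " Yes "
truthy_result = env_flag("DEMO_FLAG")

# Any set value outside the truthy set is False, even when the default is True.
os.environ["DEMO_FLAG"] = "0"
override_result = env_flag("DEMO_FLAG", default=True)
```

The last case matters for callers that pass a task-dependent default: setting the variable to `0` explicitly switches the behavior off.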
|
||||
|
||||
|
||||
def _arm_for_block(block: Any) -> str:
    return "left" if float(block.get_pose().p[0]) < 0 else "right"


def _robotwin_blocks_episode_info(task_name: str, env: Any) -> dict[str, str] | None:
    """Infer the episode-info dict used by RoboTwin's official instruction generator for block ranking."""
    if task_name == "blocks_ranking_rgb":
        return {
            "{A}": "red block",
            "{B}": "green block",
            "{C}": "blue block",
            "{a}": _arm_for_block(env.block1),
            "{b}": _arm_for_block(env.block2),
            "{c}": _arm_for_block(env.block3),
        }
    if task_name == "blocks_ranking_size":
        return {
            "{A}": "large block",
            "{B}": "medium block",
            "{C}": "small block",
            "{a}": _arm_for_block(env.block1),
            "{b}": _arm_for_block(env.block2),
            "{c}": _arm_for_block(env.block3),
        }
    return None
|
||||
|
||||
|
||||
def _generate_robotwin_official_instruction(task_name: str, env: Any) -> str:
    """Generate language with RoboTwin's official task templates, matching its eval client."""
    fallback = task_name.replace("_", " ")
    episode_info = _robotwin_blocks_episode_info(task_name, env)
    if episode_info is None:
        logger.warning("Official RoboTwin instruction is not implemented for task=%s; using %r.", task_name, fallback)
        return fallback

    try:
        from description.utils.generate_episode_instructions import generate_episode_descriptions
    except Exception:
        logger.warning("Failed to import RoboTwin official instruction generator; using %r.", fallback, exc_info=True)
        return fallback

    instruction_type = os.environ.get(OFFICIAL_INSTRUCTION_TYPE_ENV, "seen")
    try:
        max_descriptions = int(os.environ.get(OFFICIAL_INSTRUCTION_MAX_ENV, "1000000"))
    except ValueError:
        max_descriptions = 1000000

    results = generate_episode_descriptions(task_name, [episode_info], max_descriptions=max_descriptions)
    if not results:
        logger.warning("RoboTwin generated no official instructions for task=%s; using %r.", task_name, fallback)
        return fallback

    options = results[0].get(instruction_type) or results[0].get("seen") or results[0].get("unseen")
    if not options:
        logger.warning(
            "RoboTwin generated no %s official instructions for task=%s; using %r.",
            instruction_type,
            task_name,
            fallback,
        )
        return fallback

    return str(np.random.choice(options))
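The two fallback steps in this function (the split lookup and the integer cap) can be sketched in isolation. The helper names, the `DEMO_INSTRUCTION_MAX` variable, and the sample dictionary are hypothetical; the real generator's return structure is only assumed to be a dict keyed by `"seen"` / `"unseen"`, as the lookup above implies.

```python
import os


def pick_instruction_options(result, instruction_type):
    """Prefer the requested split, then fall back to "seen", then "unseen"."""
    return result.get(instruction_type) or result.get("seen") or result.get("unseen")


def read_max_descriptions(env_var="DEMO_INSTRUCTION_MAX", default=1000000):
    """Parse an integer cap from the environment, ignoring malformed values."""
    try:
        return int(os.environ.get(env_var, str(default)))
    except ValueError:
        return default


# An empty "seen" list is falsy, so the chain moves on to "unseen".
generated = {"seen": [], "unseen": ["stack the blocks by size"]}
options = pick_instruction_options(generated, "seen")

# A malformed value falls back to the default instead of raising.
os.environ["DEMO_INSTRUCTION_MAX"] = "not-a-number"
cap = read_max_descriptions()
```

Because the chain uses `or`, an empty list and a missing key are treated the same way, which is why the caller still has to handle a final falsy result.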
|
||||
|
||||
|
||||
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
|
||||
DEFAULT_CAMERA_H = 240
|
||||
DEFAULT_CAMERA_W = 320
|
||||
@@ -234,6 +342,7 @@ class RoboTwinEnv(gym.Env):
|
||||
observation_width: int | None = None,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
render_mode: str = "rgb_array",
|
||||
action_mode: str = "joint",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_name = task_name
|
||||
@@ -241,6 +350,13 @@ class RoboTwinEnv(gym.Env):
|
||||
self.task_description = task_name.replace("_", " ")
|
||||
self.episode_index = episode_index
|
||||
self._reset_stride = n_envs
|
||||
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
|
||||
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
|
||||
if action_mode not in ("joint", "ee"):
|
||||
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
|
||||
self.action_mode = action_mode
|
||||
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
|
||||
self._init_eef_pose: np.ndarray | None = None
|
||||
self.camera_names = list(camera_names)
|
||||
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
|
||||
# The YAML-driven lookup is deferred to reset() so construction doesn't
|
||||
@@ -271,7 +387,7 @@ class RoboTwinEnv(gym.Env):
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
|
||||
)
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
@@ -317,6 +433,18 @@ class RoboTwinEnv(gym.Env):
|
||||
|
||||
return {"pixels": images, "agent_pos": joint_state}
|
||||
|
||||
    def _read_eef_pose(self) -> np.ndarray:
        """Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
        assert self._env is not None, "_read_eef_pose called before _ensure_env()"
        ep = self._env.get_obs()["endpose"]
        pose = (
            list(ep["left_endpose"])
            + [ep["left_gripper"]]
            + list(ep["right_endpose"])
            + [ep["right_gripper"]]
        )
        return np.asarray(pose, dtype=np.float64)
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
|
||||
self._ensure_env()
|
||||
super().reset(seed=seed)
|
||||
@@ -330,16 +458,32 @@ class RoboTwinEnv(gym.Env):
|
||||
self.episode_index += self._reset_stride
|
||||
self._step_count = 0
|
||||
|
||||
use_official_instruction = self.task_name in {"blocks_ranking_rgb", "blocks_ranking_size"}
|
||||
if _env_flag(OFFICIAL_INSTRUCTION_ENV, default=use_official_instruction):
|
||||
self.task_description = _generate_robotwin_official_instruction(self.task_name, self._env)
|
||||
if hasattr(self._env, "set_instruction"):
|
||||
self._env.set_instruction(instruction=self.task_description)
|
||||
logger.info("RoboTwin official instruction | task=%s | %s", self.task_name, self.task_description)
|
||||
else:
|
||||
self.task_description = self.task_name.replace("_", " ")
|
||||
|
||||
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
|
||||
if self.action_mode == "ee":
|
||||
self._init_eef_pose = self._read_eef_pose()
|
||||
|
||||
obs = self._get_obs()
|
||||
return obs, {"is_success": False, "task": self.task_name}
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
assert self._env is not None, "step() called before reset()"
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
|
||||
if action.ndim != 1 or action.shape[0] != self._action_dim:
|
||||
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
|
||||
|
||||
with torch.enable_grad():
|
||||
if hasattr(self._env, "take_action"):
|
||||
if self.action_mode == "ee":
|
||||
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
|
||||
self._env.take_action(ee_action, action_type="ee")
|
||||
elif hasattr(self._env, "take_action"):
|
||||
self._env.take_action(action)
|
||||
else:
|
||||
self._env.step(action)
|
||||
@@ -398,6 +542,7 @@ def _make_env_fns(
|
||||
observation_height: int,
|
||||
observation_width: int,
|
||||
episode_length: int,
|
||||
action_mode: str = "joint",
|
||||
) -> list[Callable[[], RoboTwinEnv]]:
|
||||
"""Return n_envs factory callables for a single task."""
|
||||
|
||||
@@ -410,6 +555,7 @@ def _make_env_fns(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
|
||||
return [partial(_make_one, i) for i in range(n_envs)]
|
||||
@@ -423,6 +569,7 @@ def create_robotwin_envs(
|
||||
observation_height: int = DEFAULT_CAMERA_H,
|
||||
observation_width: int = DEFAULT_CAMERA_W,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
action_mode: str = "joint",
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboTwin 2.0 environments.
|
||||
|
||||
@@ -473,6 +620,7 @@ def create_robotwin_envs(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
|
||||
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
    """Linear warmup followed by a constant learning rate.

    Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
    the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
    """

    num_warmup_steps: int = 1000

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        warmup_steps = self.num_warmup_steps or 0

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return 1.0

        return LambdaLR(optimizer, lr_lambda, -1)
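The shape of this schedule is easy to tabulate without PyTorch. The sketch below reuses the same multiplier logic as `lr_lambda`; the function name `warmup_constant_multiplier` is hypothetical, and the peak learning rate of `1e-5` and 1000 warmup steps are taken from the LingBot-VA config defaults shown elsewhere in this diff.

```python
def warmup_constant_multiplier(current_step: int, warmup_steps: int) -> float:
    """Multiplier applied to the peak LR: linear ramp during warmup, then 1.0."""
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0


peak_lr = 1e-5
warmup = 1000
schedule = {
    step: peak_lr * warmup_constant_multiplier(step, warmup)
    for step in (0, 250, 500, 1000, 5000)
}
```

With these values the learning rate starts at zero, reaches half the peak at step 500, and stays at the peak from step 1000 onward. With `warmup_steps` set to 0 the multiplier is 1.0 from the first step.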
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
|
||||
@dataclass
|
||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
|
||||
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
@@ -44,6 +45,7 @@ __all__ = [
|
||||
"EO1Config",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"LingBotVAConfig",
|
||||
"MolmoAct2Config",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
|
||||
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
elif name == "lingbot_va":
|
||||
from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
|
||||
return LingBotVAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
elif policy_type == "lingbot_va":
|
||||
return LingBotVAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -448,6 +455,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, LingBotVAConfig):
|
||||
from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
|
||||
processors = make_lingbot_va_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
../../../../docs/source/lingbot_va.mdx
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
|
||||
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
|
||||
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
|
||||
# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when
|
||||
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
|
||||
from .configuration_lingbot_va import LingBotVAConfig
from .processor_lingbot_va import make_lingbot_va_pre_post_processors

__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]


def __getattr__(name):
    if name == "LingBotVAPolicy":
        from .modeling_lingbot_va import LingBotVAPolicy

        return LingBotVAPolicy
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
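The lazy-export pattern above relies on module-level `__getattr__` (PEP 562). A minimal, self-contained sketch follows, using an in-memory module and the standard-library `decimal.Decimal` as a stand-in for the heavy class; the module name `lazy_demo` and the `loads` list are hypothetical and exist only to make the deferred import observable.

```python
import importlib
import types

loads = []  # records each time the deferred import path runs


def _lazy_getattr(name):
    """Resolve selected attributes on first access instead of at import time."""
    if name == "Decimal":
        loads.append(name)
        return importlib.import_module("decimal").Decimal
    raise AttributeError(f"module 'lazy_demo' has no attribute {name!r}")


# Build a throwaway module and install the hook in its namespace.
lazy_demo = types.ModuleType("lazy_demo")
lazy_demo.__getattr__ = _lazy_getattr

nothing_loaded_yet = list(loads)
value = lazy_demo.Decimal("1.5")  # triggers the hook

try:
    lazy_demo.missing
    missing_raised = False
except AttributeError:
    missing_raised = True
```

The hook is only consulted when normal attribute lookup fails, so names defined eagerly in the module are unaffected. Because the result is not stored back on the module, the hook runs again on each access; the repeated import is cheap once the target module is cached in `sys.modules`.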
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for the LingBot-VA policy.
|
||||
|
||||
LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
|
||||
video-diffusion stack. It interleaves prediction of future video latents and robot
|
||||
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
|
||||
upstream repository (https://github.com/Robbyant/lingbot-va).
|
||||
|
||||
Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
|
||||
and the ``transformer/config.json`` of the released checkpoints.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("lingbot_va")
|
||||
@dataclass
|
||||
class LingBotVAConfig(PreTrainedConfig):
|
||||
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
|
||||
|
||||
# Wan transformer architecture
|
||||
patch_size: tuple[int, int, int] = (1, 2, 2)
|
||||
num_attention_heads: int = 24
|
||||
attention_head_dim: int = 128
|
||||
in_channels: int = 48
|
||||
out_channels: int = 48
|
||||
action_dim: int = 30
|
||||
text_dim: int = 4096
|
||||
freq_dim: int = 256
|
||||
ffn_dim: int = 14336
|
||||
num_layers: int = 30
|
||||
cross_attn_norm: bool = True
|
||||
eps: float = 1e-6
|
||||
rope_max_seq_len: int = 1024
|
||||
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
|
||||
attn_mode: str = "torch"
|
||||
|
||||
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
|
||||
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
|
||||
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
|
||||
wan_pretrained_path: str = "robbyant/lingbot-va-base"
|
||||
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
|
||||
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
|
||||
text_encoder_device: str = "cpu"
|
||||
|
||||
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
|
||||
obs_cam_keys: list[str] = field(
|
||||
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
|
||||
)
|
||||
# Undo the LIBERO env processor's extra horizontal flip to match the model's training orientation.
|
||||
image_hflip: bool = False
|
||||
# Camera latent layout: "width_concat" (cameras concatenated on width; LIBERO) or
|
||||
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
|
||||
camera_layout: str = "width_concat"
|
||||
|
||||
# Inference hyperparameters (LIBERO defaults)
|
||||
n_obs_steps: int = 1
|
||||
height: int = 128
|
||||
width: int = 128
|
||||
action_per_frame: int = 4
|
||||
frame_chunk_size: int = 4
|
||||
attn_window: int = 30
|
||||
num_inference_steps: int = 20
|
||||
video_exec_step: int = -1
|
||||
action_num_inference_steps: int = 50
|
||||
guidance_scale: float = 5.0
|
||||
action_guidance_scale: float = 1.0
|
||||
snr_shift: float = 5.0
|
||||
action_snr_shift: float = 0.05
|
||||
max_sequence_length: int = 512 # UMT5 prompt length
|
||||
|
||||
# Subset of the 30-d action space used by the benchmark (LIBERO = 7-DoF). The action
|
||||
# (un)normalization quantiles live in the checkpoint's ``policy_postprocessor.json``, not here.
|
||||
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
|
||||
|
||||
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
|
||||
save_predicted_video: bool = False
|
||||
|
||||
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
|
||||
# quantile-(un)normalized inside the policy / dedicated processor steps.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
scheduler_warmup_steps: int = 1000
|
||||
|
||||
    def __post_init__(self):
        super().__post_init__()
        if self.attn_mode not in ("torch", "flashattn", "flex"):
            raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")

    @property
    def chunk_size(self) -> int:
        """Number of single-step actions produced per autoregressive chunk."""
        return self.frame_chunk_size * self.action_per_frame

    @property
    def n_action_steps(self) -> int:
        """Number of actions executed before refilling (the whole chunk)."""
        return self.chunk_size
|
||||
|
||||
def validate_features(self) -> None:
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"LingBot-VA requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
# Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
|
||||
from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig
|
||||
|
||||
return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)
|
||||
|
||||
    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list[int]:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
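The relationship between the chunking fields and the derived properties can be worked through with a stripped-down stand-in. `ChunkSettings` is a hypothetical class that copies only the two relevant fields and their LIBERO defaults from `LingBotVAConfig`; it is a sketch of the arithmetic, not the real config.

```python
from dataclasses import dataclass


@dataclass
class ChunkSettings:
    """Hypothetical stand-in holding only the chunking fields of the config."""

    frame_chunk_size: int = 4   # video frames predicted per autoregressive chunk
    action_per_frame: int = 4   # actions predicted per video frame

    @property
    def chunk_size(self) -> int:
        return self.frame_chunk_size * self.action_per_frame

    @property
    def action_delta_indices(self) -> list[int]:
        return list(range(self.chunk_size))


libero = ChunkSettings()
smaller = ChunkSettings(frame_chunk_size=2)
```

With the defaults, each chunk carries 4 × 4 = 16 actions, so the action delta indices are `0` through `15` and `n_action_steps` is also 16 because the whole chunk is executed before refilling.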
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,87 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pre/post-processor pipelines for the LingBot-VA policy.
|
||||
|
||||
The preprocessor passes inputs through (IDENTITY) and the postprocessor maps the policy's
|
||||
``[-1, 1]`` actions back to physical units with the built-in ``UnnormalizerProcessorStep``
|
||||
(QUANTILES) using per-channel q01/q99 restored from the checkpoint.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
|
||||
def make_lingbot_va_pre_post_processors(
    config: LingBotVAConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the pre/post processor pipelines for LingBot-VA."""

    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device=config.device),
    ]

    # Unnormalize actions from [-1, 1] to physical units (QUANTILES) using q01/q99 restored from the checkpoint.
    output_steps: list[ProcessorStep] = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map={FeatureType.ACTION: NormalizationMode.QUANTILES},
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
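The postprocessor's quantile step maps each action channel from `[-1, 1]` back to physical units. The sketch below assumes the usual linear mapping in which `-1` corresponds to `q01` and `+1` to `q99`; this formula is an assumption about what `UnnormalizerProcessorStep` computes in QUANTILES mode, not something shown in this diff, and the function name and sample quantiles are hypothetical.

```python
def unnormalize_quantiles(action, q01, q99):
    """Map per-channel values from [-1, 1] to [q01, q99] (assumed linear mapping)."""
    return [
        (a + 1.0) * (hi - lo) / 2.0 + lo
        for a, lo, hi in zip(action, q01, q99)
    ]


# Hypothetical per-channel quantiles for a 2-channel action.
q01 = [-0.5, 0.0]
q99 = [0.5, 2.0]

lower = unnormalize_quantiles([-1.0, -1.0], q01, q99)
mixed = unnormalize_quantiles([-1.0, 0.0], q01, q99)
upper = unnormalize_quantiles([1.0, 1.0], q01, q99)
```

Under this assumption a normalized value of `0` lands at the midpoint of the quantile range, and values outside `[-1, 1]` extrapolate beyond `q01` / `q99` rather than being clipped.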
|
||||
@@ -32,6 +32,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -280,11 +281,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -342,108 +338,30 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
step_entry: dict[str, Any] = {}
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -451,110 +369,31 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
return pipeline_config
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
config["steps"].append(step_entry)
|
||||
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
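The hunk above revolves around one naming scheme: a sanitized pipeline name, a step index, and an optional registry name form the `.safetensors` filename, and the in-memory state key is that filename without its suffix. The following standalone sketch mirrors that scheme outside the class so the round trip is easy to see. The free functions and the `check_state_keys` helper are hypothetical names for illustration, not part of the LeRobot API.

```python
from __future__ import annotations

import re


def sanitize_name(name: str) -> str:
    """Lower-case the name and replace non-alphanumeric characters with underscores."""
    return re.sub(r"[^a-zA-Z0-9_]", "_", name.lower())


def state_filename(step_index: int, registry_name: str | None, sanitized_name: str) -> str:
    """Build the per-step state filename, with the registry name when available."""
    if registry_name:
        return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
    return f"{sanitized_name}_step_{step_index}.safetensors"


def state_key(filename: str) -> str:
    """Derive the in-memory key by dropping the file suffix."""
    return filename.removesuffix(".safetensors")


def check_state_keys(expected_filenames: list[str | None], provided_keys: set[str]) -> None:
    """Raise KeyError on missing or unexpected keys (strict loading)."""
    expected = {state_key(name) for name in expected_filenames if name is not None}
    missing = expected - provided_keys
    unexpected = provided_keys - expected
    if missing:
        raise KeyError(f"Missing state keys: {sorted(missing)}")
    if unexpected:
        raise KeyError(f"Unexpected state keys: {sorted(unexpected)}")


prefix = sanitize_name("Policy-Preprocessor")
filename = state_filename(2, "normalizer_processor", prefix)
key = state_key(filename)
```

Keeping the key equal to the filename stem means a state dictionary built in memory and one read back from disk can be compared directly, and stateless steps are simply represented by `None` in the expected list.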
@@ -738,54 +577,12 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
pipeline = cls(
|
||||
return cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -869,7 +666,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -889,7 +688,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -903,14 +702,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -972,41 +766,26 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
return steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
# 3. Load step state if available
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
|
||||
steps.append(step_instance)
|
||||
|
||||
return steps, override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
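Both variants of the step-building loop above track which override keys were consumed, so that leftover keys can be reported afterwards. A toy version of that bookkeeping, with steps reduced to `(key, kwargs)` tuples, might look like the sketch below. The function `build_steps` and the step keys are made up for illustration and do not reproduce `_resolve_step_class` or `_instantiate_step`.

```python
from __future__ import annotations


def build_steps(
    step_keys: list[str],
    overrides: dict[str, dict],
) -> tuple[list[tuple[str, dict]], set[str]]:
    """Pair each step key with its override kwargs and return unused override keys."""
    remaining_override_keys = set(overrides)
    steps: list[tuple[str, dict]] = []
    for step_key in step_keys:
        # A step without an override simply gets empty kwargs.
        steps.append((step_key, dict(overrides.get(step_key, {}))))
        # discard() is a no-op when the key is absent, so no membership test is needed.
        remaining_override_keys.discard(step_key)
    return steps, remaining_override_keys


steps, leftover = build_steps(
    ["normalizer_processor", "device_processor"],
    {"device_processor": {"device": "cpu"}, "typo_processor": {"foo": 1}},
)
```

Here `leftover` contains only the misspelled key, which is the information a caller needs to raise a helpful error about overrides that matched no step.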
@@ -1317,7 +1096,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1368,9 +1147,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -105,6 +105,7 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
predicted_latents_callback: Callable[[PreTrainedPolicy], None] | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -134,6 +135,9 @@ def rollout(
|
||||
are returned optionally because they typically take more memory to cache. Defaults to False.
|
||||
render_callback: Optional rendering callback to be used after the environments are reset, and after
|
||||
every step.
|
||||
predicted_latents_callback: Optional callback invoked after every ``select_action`` with the policy
|
||||
itself. World-model policies (e.g. LingBot-VA) stash predicted video latents on
|
||||
``policy.last_predicted_latents``; this lets the caller concatenate chunks and decode once.
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
@@ -184,6 +188,8 @@ def rollout(
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
if predicted_latents_callback is not None:
|
||||
predicted_latents_callback(policy)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {ACTION: action}
|
||||
@@ -203,12 +209,22 @@ def rollout(
|
||||
# available if none of the envs finished.
|
||||
if "final_info" in info:
|
||||
final_info = info["final_info"]
|
||||
if not isinstance(final_info, dict):
|
||||
raise RuntimeError(
|
||||
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
|
||||
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
|
||||
if isinstance(final_info, dict):
|
||||
is_success = final_info.get("is_success", [False] * env.num_envs)
|
||||
successes = (
|
||||
is_success.tolist()
|
||||
if hasattr(is_success, "tolist")
|
||||
else [bool(is_success)] * env.num_envs
|
||||
)
|
||||
successes = final_info["is_success"].tolist()
|
||||
else:
|
||||
# Gymnasium < 1.0 returns final_info as a per-env sequence/object array,
|
||||
# with entries set to a dict only for envs that just finished.
|
||||
successes = []
|
||||
for item in final_info:
|
||||
if isinstance(item, dict) and "is_success" in item:
|
||||
successes.append(bool(item["is_success"]))
|
||||
else:
|
||||
successes.append(False)
|
||||
elif "is_success" in info:
|
||||
is_success = info["is_success"]
|
||||
successes = (
|
||||
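The `final_info` handling above has to cope with two shapes: a single dict (Gymnasium >= 1.0) and a per-environment sequence whose entries are dicts only for environments that just finished (Gymnasium < 1.0). The sketch below isolates that branching in a standalone helper. The name `successes_from_final_info` is hypothetical, and unlike the hunk it treats a plain list value as a per-environment list rather than as a single truthy object.

```python
from __future__ import annotations

from typing import Any


def successes_from_final_info(final_info: Any, num_envs: int) -> list[bool]:
    """Extract one success flag per environment from either final_info layout."""
    if isinstance(final_info, dict):
        is_success = final_info.get("is_success")
        if is_success is None:
            return [False] * num_envs
        if hasattr(is_success, "tolist"):
            # Array-like values (e.g. NumPy arrays) become plain Python objects.
            is_success = is_success.tolist()
        if isinstance(is_success, (list, tuple)):
            return [bool(flag) for flag in is_success]
        # A scalar flag is broadcast to every environment.
        return [bool(is_success)] * num_envs

    # Legacy layout: one entry per env, a dict only where the episode just ended.
    return [
        isinstance(item, dict) and bool(item.get("is_success", False))
        for item in final_info
    ]


modern = successes_from_final_info({"is_success": [True, False]}, num_envs=2)
missing_key = successes_from_final_info({}, num_envs=2)
legacy = successes_from_final_info([None, {"is_success": True}], num_envs=2)
```

One detail worth checking in the hunk itself: the default `[False] * env.num_envs` has no `tolist` attribute, so it falls through to `[bool(is_success)] * env.num_envs`, and `bool()` of a non-empty list is `True`. The sketch avoids that by handling lists explicitly.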
@@ -273,6 +289,7 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
save_predicted_video: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -291,6 +308,11 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
|
||||
save_predicted_video = save_predicted_video or bool(
|
||||
getattr(getattr(policy, "config", None), "save_predicted_video", False)
|
||||
)
|
||||
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
exc = ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
@@ -334,6 +356,22 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if save_predicted_video:
|
||||
if not videos_dir:
|
||||
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
|
||||
predicted_video_paths: list[str] = []
|
||||
n_predicted_rendered = 0
|
||||
|
||||
# Collect predicted-video latents across a rollout (world-model policies only). The latents are
|
||||
# concatenated and decoded once after the rollout, matching upstream LingBot-VA's visualization path.
|
||||
def collect_predicted_latents(policy: PreTrainedPolicy):
|
||||
latents = getattr(policy, "last_predicted_latents", None)
|
||||
if latents is not None:
|
||||
pred_latents.append(
|
||||
latents.detach().to("cpu") if hasattr(latents, "detach") else torch.as_tensor(latents).cpu()
|
||||
)
|
||||
policy.last_predicted_latents = None
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
@@ -345,6 +383,9 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if save_predicted_video:
|
||||
pred_latents: list[torch.Tensor] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
else:
|
||||
@@ -361,6 +402,7 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
predicted_latents_callback=collect_predicted_latents if save_predicted_video else None,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -426,6 +468,35 @@ def eval_policy(
                threads.append(thread)
                n_episodes_rendered += 1

        # Maybe save the policy's predicted (imagined) video for this batch's rollout.
        if save_predicted_video and len(pred_latents) > 0:
            predicted_latent = torch.cat(pred_latents, dim=2)
            decoder = getattr(policy, "decode_predicted_latents", None) or getattr(
                policy, "_decode_predicted_video", None
            )
            if decoder is None:
                raise AttributeError(
                    "Policy config requested predicted-video saving, but the policy does not expose "
                    "`decode_predicted_latents` or `_decode_predicted_video`."
                )
            predicted_video = decoder(predicted_latent)
            if hasattr(predicted_video, "detach"):
                predicted_video = predicted_video.detach().to("cpu").numpy()
            videos_dir.mkdir(parents=True, exist_ok=True)
            predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
            predicted_video_paths.append(str(predicted_video_path))
            thread = threading.Thread(
                target=write_video,
                args=(
                    str(predicted_video_path),
                    predicted_video,
                    env.unwrapped.metadata["render_fps"],
                ),
            )
            thread.start()
            threads.append(thread)
            n_predicted_rendered += 1

        progbar.set_postfix(
            {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
        )
@@ -469,6 +540,9 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
info["video_paths"] = video_paths
|
||||
|
||||
if save_predicted_video:
|
||||
info["predicted_video_paths"] = predicted_video_paths
|
||||
|
||||
return info
|
||||
|
||||
|
||||
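The predicted-video path in `eval_policy` relies on a collect-and-clear callback: after each `select_action`, the callback copies `policy.last_predicted_latents` into a buffer and resets the attribute to `None`. The sketch below reproduces that pattern with a dummy policy and plain lists in place of tensors. `DummyWorldModelPolicy` and `make_collector` are hypothetical, and the assumption that a policy only refreshes its latents when it generates a new chunk is one plausible reason for the reset, not something the diff states.

```python
from __future__ import annotations


class DummyWorldModelPolicy:
    """Hypothetical stand-in that stashes a new latent chunk on every second call."""

    def __init__(self) -> None:
        self.last_predicted_latents: list[int] | None = None
        self._calls = 0

    def select_action(self, observation: object) -> float:
        self._calls += 1
        if self._calls % 2 == 0:
            # Only refresh the stash when a new chunk is "generated".
            self.last_predicted_latents = [self._calls]
        return 0.0


def make_collector(buffer: list[list[int]]):
    """Return a callback that moves stashed latents into the buffer exactly once."""

    def collect(policy: DummyWorldModelPolicy) -> None:
        latents = getattr(policy, "last_predicted_latents", None)
        if latents is not None:
            buffer.append(latents)
            # Clearing prevents the same chunk from being collected on the next step.
            policy.last_predicted_latents = None

    return collect


policy = DummyWorldModelPolicy()
chunks: list[list[int]] = []
collect = make_collector(chunks)
for _ in range(4):
    policy.select_action(observation=None)
    collect(policy)

# List concatenation stands in for `torch.cat(pred_latents, dim=2)`.
concatenated = [value for chunk in chunks for value in chunk]
```

Without the reset, the stale chunk from call 2 would also be appended after call 3, so the concatenated sequence would contain duplicated frames.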
@@ -600,9 +674,10 @@ class TaskMetrics(TypedDict):
|
||||
max_rewards: list[float]
|
||||
successes: list[bool]
|
||||
video_paths: list[str]
|
||||
predicted_video_paths: list[str]
|
||||
|
||||
|
||||
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
|
||||
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths", "predicted_video_paths")
|
||||
|
||||
|
||||
def eval_one(
|
||||
@@ -643,6 +718,7 @@ def eval_one(
|
||||
max_rewards=[ep["max_reward"] for ep in per_episode],
|
||||
successes=[ep["success"] for ep in per_episode],
|
||||
video_paths=task_result.get("video_paths", []),
|
||||
predicted_video_paths=task_result.get("predicted_video_paths", []),
|
||||
)
|
||||
|
||||
|
||||
@@ -689,6 +765,7 @@ def run_one(
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
metrics.setdefault("predicted_video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
|
||||
|
||||
@@ -742,11 +819,11 @@ def eval_policy_all(
|
||||
_append("sum_rewards", metrics.get("sum_rewards"))
|
||||
_append("max_rewards", metrics.get("max_rewards"))
|
||||
_append("successes", metrics.get("successes"))
|
||||
# video_paths is list-like
|
||||
paths = metrics.get("video_paths", [])
|
||||
if paths:
|
||||
group_acc[group]["video_paths"].extend(paths)
|
||||
overall["video_paths"].extend(paths)
|
||||
for key in ("video_paths", "predicted_video_paths"):
|
||||
paths = metrics.get(key, [])
|
||||
if paths:
|
||||
group_acc[group][key].extend(paths)
|
||||
overall[key].extend(paths)
|
||||
|
||||
# Choose runner (sequential vs threaded)
|
||||
task_runner = partial(
|
||||
@@ -814,6 +891,7 @@ def eval_policy_all(
|
||||
"pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
|
||||
"n_episodes": len(acc["sum_rewards"]),
|
||||
"video_paths": list(acc["video_paths"]),
|
||||
"predicted_video_paths": list(acc["predicted_video_paths"]),
|
||||
}
|
||||
|
||||
# overall aggregates
|
||||
@@ -825,6 +903,7 @@ def eval_policy_all(
|
||||
"eval_s": time.time() - start_t,
|
||||
"eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])),
|
||||
"video_paths": list(overall["video_paths"]),
|
||||
"predicted_video_paths": list(overall["predicted_video_paths"]),
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -232,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -387,21 +384,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames") and not cfg.dataset.streaming:
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
@@ -426,16 +416,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
# Prepare everything with accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming IterableDataset is already rank-disjoint via split_dataset_by_node, so we must
|
||||
# NOT hand the dataloader to accelerate: its IterableDatasetShard would keep only every
|
||||
# world_size-th batch of each rank's already-disjoint stream (silently training on 1/N of the
|
||||
# data while decoding all of it). Batches are moved to the device manually in the loop below.
|
||||
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
|
||||
else:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -475,9 +458,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming dataloader is not accelerate-prepared (see above), so move to device here.
|
||||
batch = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
|
||||
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
|
||||
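The comment in the `@@ -426,16 +416,9 @@` hunk argues that a streaming dataset which is already split per rank must not be sharded a second time by the dataloader wrapper, because each rank would then train on only a fraction of its own stream. The toy model below illustrates that arithmetic with plain lists. The helpers `split_by_node`, `batched` and `shard_batches` are hypothetical simplifications and do not reproduce the exact behaviour of `split_dataset_by_node` or accelerate's `IterableDatasetShard`.

```python
from __future__ import annotations


def split_by_node(samples: list[int], rank: int, world_size: int) -> list[int]:
    """First sharding: each rank gets a disjoint slice of the sample stream."""
    return samples[rank::world_size]


def batched(stream: list[int], batch_size: int) -> list[list[int]]:
    """Group a stream into consecutive batches."""
    return [stream[i : i + batch_size] for i in range(0, len(stream), batch_size)]


def shard_batches(batches: list[list[int]], rank: int, world_size: int) -> list[list[int]]:
    """Second sharding: keep only every world_size-th batch of the rank's own stream."""
    return batches[rank::world_size]


world_size = 2
batch_size = 2
samples = list(range(16))

seen_single: set[int] = set()
seen_double: set[int] = set()
for rank in range(world_size):
    batches = batched(split_by_node(samples, rank, world_size), batch_size)
    seen_single.update(value for batch in batches for value in batch)
    seen_double.update(
        value for batch in shard_batches(batches, rank, world_size) for value in batch
    )

coverage_single = len(seen_single) / len(samples)
coverage_double = len(seen_double) / len(samples)
```

With two ranks, single sharding covers every sample while double sharding covers half of them, which matches the "1/N of the data" figure in the comment for this simplified model.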
@@ -1,150 +0,0 @@
"""Acceptance tests for manifest byte-index sidecars.

Run on a compute node (not login-node):

    srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\
        bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\
        env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v'
"""

from __future__ import annotations

import json
import socket

import pytest

pytest.importorskip("torchcodec")

REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
MAX_EPISODES = 64

COMPUTE_NODE = pytest.mark.skipif(
    "login" in socket.gethostname(),
    reason="run on compute node via srun (see module docstring), not login-node",
)


@pytest.fixture(scope="module")
def byte_index_dir(tmp_path_factory):
    from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index
    from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

    out = tmp_path_factory.mktemp("byte_index")
    meta = LeRobotDatasetMetadata(REPO, revision=REV)
    files, episodes, _ = build_byte_index_tables(
        meta, BUCKET, workers=4, max_episodes=MAX_EPISODES, include_keyframes=False
    )
    write_byte_index(out, files, episodes, None, merge_existing=False)
    return out, meta


@pytest.mark.integration
@COMPUTE_NODE
def test_index_load_fast_and_small(byte_index_dir):
    from lerobot.datasets.byte_index import EpisodeByteIndex

    out, meta = byte_index_dir
    index = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
    assert index.load_time_s < 1.0
    assert index.resident_bytes < 1_000_000_000


@pytest.mark.integration
@COMPUTE_NODE
def test_tight_fetch_under_25mb(byte_index_dir):
    from lerobot.datasets.byte_index import EpisodeByteIndex
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, meta = byte_index_dir
    index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
    for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
        cache.submit_prefetch(ep)
        cache.ensure_ready(ep)
    stats = cache.stats.stats_dict()
    assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024


@pytest.mark.integration
@COMPUTE_NODE
def test_in_memory_build_matches_parquet(byte_index_dir):
    from lerobot.datasets.byte_index import EpisodeByteIndex
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory

    out, meta = byte_index_dir
    disk = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
    mem = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
        for cam in meta.video_keys:
            a = disk.lookup(ep, cam)
            b = mem.lookup(ep, cam)
            assert a.mdat_offset == b.mdat_offset
            assert a.mdat_length == b.mdat_length
            assert abs(a.first_pts - b.first_pts) < 1e-6


@pytest.mark.integration
@COMPUTE_NODE
def test_custom_frame_mappings_available(byte_index_dir):
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory

    _, meta = byte_index_dir
    index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cam = meta.video_keys[0]
    ep = MAX_EPISODES // 2
    payload = index.custom_frame_mappings(ep, cam)
    assert payload is not None
    data = json.loads(payload)
    assert len(data["frames"]) > 10
    assert any(f["key_frame"] for f in data["frames"])
    assert all("pts" in f and "duration" in f for f in data["frames"])


@pytest.mark.integration
@COMPUTE_NODE
def test_metadata_skip_decoder_init(byte_index_dir):
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, meta = byte_index_dir
    index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET)
    cam = meta.video_keys[0]
    ep = 0
    cache.submit_prefetch(ep)
    cache.ensure_ready(ep)
    dec = cache.get_decoder(ep, cam)
    assert dec.metadata.num_frames is not None
    assert dec.metadata.num_frames > 0
    begin = float(dec.metadata.begin_stream_seconds)
    end = float(dec.metadata.end_stream_seconds)
    ts = begin + 0.5 * (end - begin)
    frame = dec.get_frames_played_at([ts]).data
    assert frame.ndim == 4


@pytest.mark.integration
@COMPUTE_NODE
def test_sparse_decode_produces_frames(byte_index_dir):
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, meta = byte_index_dir
    index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
    cam = meta.video_keys[0]
    ep = 0
    cache.submit_prefetch(ep)
    cache.ensure_ready(ep)
    dec = cache.get_decoder(ep, cam)
    begin = float(dec.metadata.begin_stream_seconds)
    end = float(dec.metadata.end_stream_seconds)
    ts = begin + 0.5 * (end - begin)
    frame = dec.get_frames_played_at([ts]).data
    assert frame.ndim == 4
    assert frame.numel() > 0
    assert float(frame.float().std()) > 1.0
@@ -114,30 +114,6 @@ def test_shuffle():
|
||||
    assert set(sampler) == {0, 1, 2, 3, 4, 5}


def test_shuffle_with_generator_is_deterministic():
    # Two samplers shuffling with same-seed generators must yield identical permutations.
    # This is what keeps batch shards disjoint across ranks in distributed training, where
    # accelerate synchronizes the sampler's generator state instead of the global torch RNG.
    sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
    sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
    assert list(sampler_a) == list(sampler_b)

    # Desyncing the global RNG must not affect the permutation.
    sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
    order_before = list(sampler_c)
    sampler_c.generator.manual_seed(42)
    torch.randperm(1000)  # consume global RNG, as rank-asymmetric code (e.g. eval) would
    assert list(sampler_c) == order_before


def test_generator_attribute_defaults_to_none():
    # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
    # so the attribute must exist even when no generator is passed.
    sampler = EpisodeAwareSampler([0], [6], shuffle=True)
    assert sampler.generator is None
    assert set(sampler) == {0, 1, 2, 3, 4, 5}


def test_negative_drop_first_frames_raises():
    with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
        EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)

@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -24,6 +25,52 @@ from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
    """Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
    rng = np.random.default_rng(streaming_ds.seed)
    buffer_size = streaming_ds.buffer_size
    num_shards = streaming_ds.num_shards

    shards_indices = []
    for shard_idx in range(num_shards):
        shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
        shard_indices = [item["index"] for item in shard]
        shards_indices.append(shard_indices)

    shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}

    buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)

    frames_buffer = []
    expected_indices = []

    while shard_iterators:  # While there are still available shards
        available_shard_keys = list(shard_iterators.keys())
        if not available_shard_keys:
            break

        # Call _infinite_generator_over_elements with current available shards (key difference!)
        shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))

        try:
            frame_index = next(shard_iterators[shard_key])

            if len(frames_buffer) == buffer_size:
                i = next(buffer_indices_generator)
                expected_indices.append(frames_buffer[i])
                frames_buffer[i] = frame_index
            else:
                frames_buffer.append(frame_index)

        except StopIteration:
            del shard_iterators[shard_key]  # Remove exhausted shard

    rng.shuffle(frames_buffer)
    expected_indices.extend(frames_buffer)

    return expected_indices


def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if frames are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
@@ -73,9 +120,10 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
@@ -90,17 +138,25 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
|
||||
for epoch_indices in epochs:
|
||||
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
if shuffle:
|
||||
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
|
||||
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
|
||||
else:
|
||||
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -108,11 +164,15 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
@@ -127,21 +187,31 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
def make_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
episode_pool_size=3,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
first = [int(frame["index"]) for frame in make_ds()]
|
||||
again = [int(frame["index"]) for frame in make_ds()]
|
||||
|
||||
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
assert first == again, "same seed must reproduce the same order"
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -218,11 +288,6 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats/ints where map-style yields
|
||||
# 0-dim tensors (long-standing accepted difference). Compare by value.
|
||||
check = float(left) == float(right)
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""End-to-end distributed streaming smoke test under a real `accelerate launch`.

Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
"""

import json
import shutil
import subprocess
import sys

import pytest

pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")

from tests.fixtures.constants import DUMMY_REPO_ID

WORKER = """
import json, sys
from accelerate import PartialState
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset

root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
    repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
    json.dump(payload, f)
"""


@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
    total_frames = 160
    repo_id = f"{DUMMY_REPO_ID}-acc"
    root = tmp_path / "ds"
    lerobot_dataset_factory(
        root=root,
        repo_id=repo_id,
        total_episodes=8,
        total_frames=total_frames,
        use_videos=False,
        data_files_size_in_mb=0.001,
        chunks_size=1,
    )

    worker = tmp_path / "worker.py"
    worker.write_text(WORKER)
    out_dir = tmp_path / "out"
    out_dir.mkdir()

    cmd = [
        "accelerate",
        "launch",
        "--num_processes=2",
        "--num_machines=1",
        "--mixed_precision=no",
        "--dynamo_backend=no",
        "--cpu",
        str(worker),
        str(root),
        repo_id,
        str(out_dir),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    assert result.returncode == 0, (
        f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
    )

    payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
    if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
        pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")

    rank_sets = [set(p["indices"]) for p in payloads]
    assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
    assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"


if __name__ == "__main__":
    sys.exit(pytest.main([__file__, "-v"]))
@@ -1,430 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
prefetching, deterministic fast-forward resume, and schema parity."""

import pytest
import torch
from torch.utils.data import DataLoader

pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")

from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID


def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
    factory(
        root=root,
        repo_id=repo_id,
        total_episodes=total_episodes,
        total_frames=total_frames,
        use_videos=use_videos,
        data_files_size_in_mb=0.001,
        chunks_size=1,
        **kw,
    )


def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
    return [int(frame["index"]) for frame in ds]


def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
    assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)

    monkeypatch.delenv("RANK", raising=False)
    monkeypatch.delenv("WORLD_SIZE", raising=False)
    # No accelerate state, no env -> single process.
    assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)

    monkeypatch.setenv("RANK", "3")
    monkeypatch.setenv("WORLD_SIZE", "4")
    assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)


def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
    """Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
    repo_id = f"{DUMMY_REPO_ID}-ranks"
    total_frames, total_episodes = 200, 8
    _make_local_dataset(
        lerobot_dataset_factory,
        tmp_path / "ds",
        repo_id,
        total_episodes=total_episodes,
        total_frames=total_frames,
    )

    world_size = 2
    per_rank = []
    for rank in range(world_size):
        ds = StreamingLeRobotDataset(
            repo_id=repo_id,
            root=tmp_path / "ds",
            shuffle=False,
            episode_pool_size=8,
            max_num_shards=8,
            rank=rank,
            world_size=world_size,
        )
        per_rank.append(set(_stream_indices(ds)))

    assert per_rank[0].isdisjoint(per_rank[1]), (
        "ranks streamed overlapping frames (duplicate data across GPUs)"
    )
    assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"


def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
    """DataLoader workers within a rank must split shards so no frame is yielded twice."""
    repo_id = f"{DUMMY_REPO_ID}-workers"
    total_frames, total_episodes = 120, 8
    _make_local_dataset(
        lerobot_dataset_factory,
        tmp_path / "ds",
        repo_id,
        total_episodes=total_episodes,
        total_frames=total_frames,
    )

    ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
    )
    loader = DataLoader(ds, batch_size=None, num_workers=2)
    indices = [int(batch["index"]) for batch in loader]

    assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"


def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
    """A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.

    SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
    """
    repo_id = f"{DUMMY_REPO_ID}-sarm"
    # A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
    # gives episodes variable lengths, so multi-episode boundaries can't be assumed).
    episode_frames = 300
    _make_local_dataset(
        lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
    )

    horizon_s = 5.0  # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
    delta_timestamps = {ACTION: [0.0, horizon_s]}
    ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=tmp_path / "ds",
        shuffle=False,
        episode_pool_size=1,
        max_num_shards=1,
        delta_timestamps=delta_timestamps,
    )

    horizon_frames = int(round(horizon_s * ds.fps))
    assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
    checked = 0
    for frame in ds:
        idx = int(frame["index"])
        # The +horizon target is inside the single episode -> it must be a real frame, not padding.
        if idx + horizon_frames < episode_frames:
            assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
                f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
            )
            checked += 1
    assert checked > 0, "test did not exercise any in-episode long-horizon frame"


def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
    repo_id = f"{DUMMY_REPO_ID}-seeds"
    _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)

    def order(seed):
        return _stream_indices(
            StreamingLeRobotDataset(
                repo_id=repo_id,
                root=tmp_path / "ds",
                shuffle=True,
                seed=seed,
                episode_pool_size=4,
                max_num_shards=2,
            )
        )

    assert order(0) == order(0), "same seed must reproduce the same order"
    assert order(0) != order(1), "different seeds should give different orders"


def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
    """Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
    repo_id = f"{DUMMY_REPO_ID}-epochs"
    total_frames = 120
    _make_local_dataset(
        lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
    )
    ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
    )
    epoch_0 = _stream_indices(ds)
    epoch_1 = _stream_indices(ds)
    assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
    assert epoch_0 != epoch_1, "epoch did not reshuffle"


def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
    """Early samples should already come from several distinct episodes (the pool's purpose)."""
    repo_id = f"{DUMMY_REPO_ID}-mix"
    _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
    ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
    )
    episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
    assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"


def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
    """Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
    repo_id = f"{DUMMY_REPO_ID}-parity"
    map_ds = lerobot_dataset_factory(
        root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
    )
    stream_ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
    )

    map_frame = map_ds[0]
    stream_frame = next(iter(stream_ds))

    assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
    for key, value in stream_frame.items():
        ref = map_frame[key]
        if isinstance(value, torch.Tensor):
            assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
                f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
            )
        elif isinstance(value, str):
            assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
        else:
            # Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
            # (a long-standing, accepted difference). Compare by value rather than exact type.
            assert float(value) == float(ref), f"{key}: {value} vs {ref}"


def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
    """For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
    import lerobot.datasets.streaming_dataset as sd

    repo_id = f"{DUMMY_REPO_ID}-vpath"
    lerobot_dataset_factory(
        root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
    )
    ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
    )

    seen_paths = []

    def fake_decode(video_path, query_ts, *args, **kwargs):
        seen_paths.append(str(video_path))
        return torch.zeros(len(query_ts), 3, 64, 96)

    monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
    next(iter(ds))

    assert seen_paths, "no video decode was issued"
    assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths


def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
    """With shuffle on, streamed frame order must differ from the underlying sequential order."""
    repo_id = f"{DUMMY_REPO_ID}-shuf"
    _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
    ordered = _stream_indices(
        StreamingLeRobotDataset(
            repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
        )
    )
    shuffled = _stream_indices(
        StreamingLeRobotDataset(
            repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
        )
    )
    assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
    assert shuffled != ordered, "shuffle did not decorrelate output order"


def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
    """Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
    repo_id = f"{DUMMY_REPO_ID}-native-resume"
    total_frames = 100
    _make_local_dataset(
        lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
    )

    def fresh_ds():
        return StreamingLeRobotDataset(
            repo_id=repo_id,
            root=tmp_path / "ds",
            shuffle=True,
            seed=7,
            episode_pool_size=2,
            frame_shuffle_buffer_size=8,
        )

    ds = fresh_ds()
    it = iter(ds)
    consumed = [int(next(it)["index"]) for _ in range(30)]
    state = ds.state_dict()

    resumed_ds = fresh_ds()
    resumed_ds.load_state_dict(state)
    rest = [int(frame["index"]) for frame in resumed_ds]

    assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
    # in-flight buffer contents are skipped on resume (documented datasets behavior):
    # bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
    covered = len(set(consumed) | set(rest))
    max_in_flight = 2 * 30 + 8
    assert covered >= total_frames - max_in_flight
    assert covered + len(consumed) >= total_frames - max_in_flight


def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
    """The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
    repo_id = f"{DUMMY_REPO_ID}-native-pipe"
    _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
    ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
    import datasets as hf_datasets

    assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
    state = ds._pipeline.state_dict()  # the native resume protocol is available end-to-end
    assert state is not None


# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
|
||||
|
||||
|
||||
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
|
||||
"""With one row group per episode (the writer's invariant), reshard() turns each episode into its
|
||||
own shard, so num_shards == total_episodes even when many episodes share a single data file."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-reshard"
|
||||
total_episodes = 3
|
||||
# Default (large) data-file size packs all (unequal-length) episodes into one file, so the only way
|
||||
# num_shards can reach total_episodes is per-row-group resharding.
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=90,
|
||||
use_videos=False,
|
||||
)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
|
||||
|
||||
file_to_eps = ds._episode_files()
|
||||
assert len(file_to_eps) == 1, "test expects all episodes packed into a single data file"
|
||||
for (chunk_idx, file_idx), eps in file_to_eps.items():
|
||||
rel = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
assert pq.ParquetFile(str(ds.root / rel)).num_row_groups == len(eps)
|
||||
|
||||
assert ds.num_shards == total_episodes
|
||||
|
||||
|
||||
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""max_buffer_input_shards (== concurrently-live random episodes) drives the per-batch episode mix:
|
||||
a single batch should already span most of the live episodes."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-frac"
|
||||
total_episodes = 8
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=240,
|
||||
use_videos=False,
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
episode_pool_size=total_episodes,
|
||||
max_buffer_input_shards=total_episodes,
|
||||
)
|
||||
assert ds.max_buffer_input_shards == total_episodes
|
||||
|
||||
batch = 32
|
||||
head = {int(frame["episode_index"]) for _, frame in zip(range(batch), ds, strict=False)}
|
||||
assert len(head) >= min(total_episodes, batch) - 2, f"batch did not mix random episodes: {head}"
|
||||
|
||||
|
||||
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
|
||||
"""A data file that collapses several episodes into a single row group (bulk df.to_parquet /
|
||||
push_to_hub) must be rejected with an actionable error: reshard() cannot address its episodes."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-collapsed"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
# Rewrite every data file as a single row group (simulating the aggregate/push_to_hub collapse).
|
||||
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
|
||||
pq.write_table(pq.read_table(parquet_path), parquet_path)
|
||||
|
||||
with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
|
||||
StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
|
||||
|
||||
|
||||
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
|
||||
"""validate_row_groups=False skips the row-group check (collapsed datasets still load, degraded)."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
|
||||
pq.write_table(pq.read_table(parquet_path), parquet_path)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, validate_row_groups=False
|
||||
)
|
||||
assert sorted(int(frame["index"]) for frame in ds) == list(range(90))
|
||||
|
||||
|
||||
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""When num_shards (== episodes after reshard) is not divisible by world_size, every rank would
|
||||
stream the whole dataset; the guard must raise instead of silently degrading."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-divis"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
with pytest.raises(ValueError, match="not divisible by world_size"):
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, rank=0, world_size=2
|
||||
)
|
||||
|
||||
# Bypassing the guard downgrades it to a warning (no raise).
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=3,
|
||||
rank=0,
|
||||
world_size=2,
|
||||
validate_row_groups=False,
|
||||
)
|
||||
assert ds.num_shards == 3
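The guard exercised by this test can be sketched as a small standalone check. The function name `check_shard_divisibility` and its `strict` flag are assumptions made for illustration; the test above only shows that the dataset raises a `ValueError` matching "not divisible by world_size" and that a bypass avoids the raise.

```python
import warnings


def check_shard_divisibility(num_shards: int, world_size: int, strict: bool = True) -> bool:
    """Return True when shards split evenly across ranks; raise or warn otherwise."""
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if num_shards % world_size == 0:
        return True
    message = f"num_shards={num_shards} is not divisible by world_size={world_size}"
    if strict:
        raise ValueError(message)
    # Non-strict mode: surface the problem without aborting.
    warnings.warn(message, stacklevel=2)
    return False
```

With 3 episodes and 2 ranks the strict call raises, while `strict=False` returns `False` after emitting a warning, mirroring the raise-then-bypass sequence in the test.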
|
||||
Vendored (+3, -22)
@@ -17,7 +17,6 @@ from pathlib import Path
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
-import pyarrow.parquet as pq
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -36,24 +35,6 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
|
||||
|
||||
-def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
-    """Write ``hf_dataset`` to ``path`` with one Parquet row group per episode.
-
-    Mirrors the LeRobot recording writer (one ``write_table`` per episode) so each episode stays an
-    independently addressable shard after ``datasets.IterableDataset.reshard()``, which
-    ``StreamingLeRobotDataset`` relies on. ``Dataset.to_parquet`` would collapse the file into a
-    single row group instead.
-    """
-    table = hf_dataset.with_format("arrow")[:]
-    episode_index = np.asarray(hf_dataset["episode_index"])
-    boundaries = np.where(np.diff(episode_index) != 0)[0] + 1
-    starts = [0, *boundaries.tolist()]
-    ends = [*boundaries.tolist(), len(episode_index)]
-    with pq.ParquetWriter(str(path), table.schema) as writer:
-        for start, end in zip(starts, ends, strict=True):
-            writer.write_table(table.slice(start, end - start))
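The boundary computation in the helper above (find where `episode_index` changes, then pair up starts and ends) can be reproduced without NumPy or PyArrow. This is a minimal pure-Python sketch; `episode_slices` is a hypothetical name, and unlike the helper it returns an empty list for an empty input instead of a single empty slice.

```python
def episode_slices(episode_index):
    """Return (start, end) row ranges, one per run of equal episode indices."""
    if not episode_index:
        return []
    # A boundary is any row whose episode index differs from the previous row.
    boundaries = [
        i for i in range(1, len(episode_index)) if episode_index[i] != episode_index[i - 1]
    ]
    starts = [0, *boundaries]
    ends = [*boundaries, len(episode_index)]
    return list(zip(starts, ends))


slices = episode_slices([0, 0, 0, 1, 1, 2])
```

Each `(start, end)` pair corresponds to one `write_table(table.slice(start, end - start))` call, which is what yields one row group per episode.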
|
||||
|
||||
|
||||
def write_hf_dataset(
|
||||
hf_dataset: Dataset,
|
||||
local_dir: Path,
|
||||
@@ -86,7 +67,7 @@ def write_hf_dataset(
|
||||
# If the dataset is small enough, write it to a single file.
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
-_to_parquet_one_row_group_per_episode(hf_dataset, path)
+hf_dataset.to_parquet(path)
|
||||
return
|
||||
|
||||
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
|
||||
@@ -133,8 +114,8 @@ def write_hf_dataset(
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
-# Write the shard to a Parquet file (one row group per episode).
-_to_parquet_one_row_group_per_episode(dataset_shard, path)
+# Write the shard to a Parquet file.
+dataset_shard.to_parquet(path)
|
||||
|
||||
# Update chunk and file indices for the next iteration.
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
def make_config(**overrides) -> LingBotVAConfig:
|
||||
kwargs = {"device": "cpu"}
|
||||
kwargs.update(overrides)
|
||||
return LingBotVAConfig(**kwargs)
|
||||
|
||||
|
||||
def test_registered_in_choice_registry() -> None:
|
||||
assert "lingbot_va" in PreTrainedConfig.get_known_choices()
|
||||
assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig
|
||||
|
||||
|
||||
def test_type_property() -> None:
|
||||
assert make_config().type == "lingbot_va"
|
||||
|
||||
|
||||
def test_chunk_size_and_action_steps() -> None:
|
||||
cfg = make_config(frame_chunk_size=4, action_per_frame=4)
|
||||
assert cfg.chunk_size == 16
|
||||
assert cfg.n_action_steps == 16
|
||||
assert cfg.action_delta_indices == list(range(16))
|
||||
assert cfg.observation_delta_indices is None
|
||||
assert cfg.reward_delta_indices is None
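The relationships this test pins down can be summarized in a small stand-in class. This is only a sketch: `ChunkConfigSketch` is hypothetical, and the product rule `chunk_size = frame_chunk_size * action_per_frame` is an assumption that happens to match the asserted values (4 and 4 giving 16); the real `LingBotVAConfig` may derive them differently.

```python
from dataclasses import dataclass


@dataclass
class ChunkConfigSketch:
    """Hypothetical stand-in for the chunk-related fields of the policy config."""

    frame_chunk_size: int = 4
    action_per_frame: int = 4

    @property
    def chunk_size(self) -> int:
        # Assumed: one chunk covers every action of every frame in the chunk.
        return self.frame_chunk_size * self.action_per_frame

    @property
    def n_action_steps(self) -> int:
        return self.chunk_size

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.chunk_size))


cfg_sketch = ChunkConfigSketch(frame_chunk_size=4, action_per_frame=4)
```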
|
||||
|
||||
|
||||
def test_optimizer_and_scheduler_presets() -> None:
|
||||
cfg = make_config()
|
||||
opt = cfg.get_optimizer_preset()
|
||||
assert opt.lr == cfg.optimizer_lr
|
||||
sched = cfg.get_scheduler_preset()
|
||||
assert sched.num_warmup_steps == cfg.scheduler_warmup_steps
|
||||
|
||||
|
||||
def test_validate_features_sets_action_feature() -> None:
|
||||
cfg = make_config()
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
|
||||
cfg.output_features = {}
|
||||
cfg.validate_features()
|
||||
assert ACTION in cfg.output_features
|
||||
assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)
|
||||
|
||||
|
||||
def test_validate_features_no_visual_raises() -> None:
|
||||
cfg = make_config()
|
||||
cfg.input_features = {}
|
||||
cfg.output_features = {}
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
cfg.validate_features()
|
||||
|
||||
|
||||
def test_invalid_attn_mode_raises() -> None:
|
||||
with pytest.raises(ValueError, match="attn_mode"):
|
||||
make_config(attn_mode="banana")
|
||||
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
|
||||
def test_make_policy_config_returns_lingbot_va() -> None:
|
||||
cfg = make_policy_config("lingbot_va", device="cpu")
|
||||
assert isinstance(cfg, LingBotVAConfig)
|
||||
|
||||
|
||||
def test_get_policy_class_resolves_lazily() -> None:
|
||||
# Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
|
||||
pytest.importorskip("diffusers")
|
||||
pytest.importorskip("transformers")
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
|
||||
cls = get_policy_class("lingbot_va")
|
||||
assert cls.name == "lingbot_va"
|
||||
assert cls.config_class is LingBotVAConfig
|
||||
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
|
||||
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import ( # noqa: E402
|
||||
FlowMatchScheduler,
|
||||
data_seq_to_patch,
|
||||
get_mesh_id,
|
||||
)
|
||||
|
||||
|
||||
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
assert sch.timesteps.shape == (20,)
|
||||
diffs = sch.timesteps[1:] - sch.timesteps[:-1]
|
||||
assert torch.all(diffs <= 0)  # non-increasing
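Why the timesteps come out monotone can be seen from the usual shifted flow-matching schedule. The sketch below assumes the vendored `FlowMatchScheduler` follows the common convention (evenly spaced sigmas, the extra final step dropped, then the shift map `shift * s / (1 + (shift - 1) * s)`); the parameter names `sigma_max` and `num_train_timesteps` are assumptions, not taken from the source.

```python
def shifted_flow_timesteps(num_steps, shift=5.0, sigma_max=1.0, sigma_min=0.0,
                           num_train_timesteps=1000):
    """Return num_steps timesteps from a shifted, linearly spaced sigma schedule."""
    # Equivalent to linspace(sigma_max, sigma_min, num_steps + 1)[:-1].
    raw = [sigma_max + (sigma_min - sigma_max) * i / num_steps for i in range(num_steps)]
    # The shift map is increasing in s for shift > 0, so ordering is preserved.
    shifted = [shift * s / (1 + (shift - 1) * s) for s in raw]
    return [s * num_train_timesteps for s in shifted]


timesteps = shifted_flow_timesteps(20)
```

Because the shift map is strictly increasing on [0, 1], a decreasing sigma grid always yields decreasing timesteps, which is the property the test asserts.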
|
||||
|
||||
|
||||
def test_flow_match_scheduler_step_preserves_shape() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
sample = torch.zeros(1, 48, 4, 8, 16)
|
||||
out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
|
||||
assert out.shape == sample.shape
|
||||
|
||||
|
||||
def test_flow_match_scheduler_add_noise() -> None:
|
||||
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
|
||||
sch.set_timesteps(20)
|
||||
sample = torch.randn(1, 48, 4, 8, 16)
|
||||
noise = torch.randn_like(sample)
|
||||
noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
|
||||
assert noisy.shape == sample.shape
|
||||
|
||||
|
||||
def test_get_mesh_id_latent_shape() -> None:
|
||||
grid = get_mesh_id(4, 8, 16, 0, 1, 0)
|
||||
assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens
|
||||
|
||||
|
||||
def test_get_mesh_id_action_shape() -> None:
|
||||
grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
|
||||
assert grid.shape == (4, 4 * 4 * 1)
|
||||
# Action rows for h/w are sentinel -1.
|
||||
assert torch.all(grid[1] < 0)
|
||||
assert torch.all(grid[2] < 0)
|
||||
|
||||
|
||||
def test_data_seq_to_patch_roundtrip_shape() -> None:
|
||||
b, f, h, w, c = 1, 4, 8, 16, 48
|
||||
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
|
||||
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
|
||||
assert out.shape == (b, c, f, h, w)
|
||||
|
||||
|
||||
def test_training_step_reduces_loss_tiny_flex() -> None:
|
||||
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
|
||||
|
||||
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
import pytest
|
||||
|
||||
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
cfg = LingBotVAConfig(
|
||||
attn_mode="flex",
|
||||
dtype="bfloat16",
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
action_dim=8,
|
||||
text_dim=32,
|
||||
freq_dim=64,
|
||||
ffn_dim=64,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=24,
|
||||
num_layers=2,
|
||||
frame_chunk_size=2,
|
||||
action_per_frame=4,
|
||||
used_action_channel_ids=[0, 1, 2, 3],
|
||||
obs_cam_keys=[f"{OBS_IMAGES}.image"],
|
||||
device="cuda",
|
||||
)
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
|
||||
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
|
||||
cfg.validate_features()
|
||||
|
||||
policy = LingBotVAPolicy(cfg).to("cuda")
|
||||
policy.train()
|
||||
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
|
||||
|
||||
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
|
||||
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
|
||||
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
|
||||
amask = torch.zeros(cfg.action_dim, device="cuda")
|
||||
amask[cfg.used_action_channel_ids] = 1.0
|
||||
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
|
||||
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
|
||||
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
|
||||
opt.step()
|
||||
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline, UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
def _make_config() -> LingBotVAConfig:
|
||||
cfg = LingBotVAConfig(device="cpu")
|
||||
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
|
||||
cfg.output_features = {}
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
|
||||
|
||||
def test_make_pre_post_processors_names_and_steps() -> None:
|
||||
cfg = _make_config()
|
||||
pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
|
||||
# Actions are unnormalized by the standard built-in quantile unnormalizer.
|
||||
assert any(isinstance(s, UnnormalizerProcessorStep) for s in post.steps)
|
||||
|
||||
|
||||
def test_freshly_built_postprocessor_is_identity() -> None:
|
||||
# Without action stats the quantile unnormalizer is a no-op (identity passthrough): the real
|
||||
# per-benchmark q01/q99 are restored from the saved checkpoint on load, not hardcoded here.
|
||||
cfg = _make_config()
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
normed = torch.tensor([[0.3, -0.5, 1.0, -1.0, 0.0, 0.7, -0.2]])
|
||||
assert torch.allclose(post(normed), normed, atol=1e-6)
|
||||
|
||||
|
||||
def test_postprocessor_quantile_unnormalization() -> None:
|
||||
# QUANTILES unnormalize maps [-1, 1] -> [q01, q99]: -1 -> q01, +1 -> q99.
|
||||
cfg = _make_config()
|
||||
q01 = [-1.0, -0.5, 0.0, -1.0, -1.0, -1.0, -1.0]
|
||||
q99 = [1.0, 0.5, 2.0, 1.0, 1.0, 1.0, 1.0]
|
||||
stats = {ACTION: {"q01": q01, "q99": q99}}
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=stats)
|
||||
out_lo = post(torch.full((1, 7), -1.0))
|
||||
out_hi = post(torch.full((1, 7), 1.0))
|
||||
assert torch.allclose(out_lo, torch.tensor(q01).unsqueeze(0), atol=1e-4)
|
||||
assert torch.allclose(out_hi, torch.tensor(q99).unsqueeze(0), atol=1e-4)
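The mapping stated in the comment of this test (-1 goes to q01, +1 goes to q99) corresponds to a simple affine transform. The following is a minimal pure-Python sketch of that arithmetic; `quantile_unnormalize` and `quantile_normalize` are hypothetical helpers, and the real `UnnormalizerProcessorStep` works on tensors and handles missing stats, which this sketch does not.

```python
def quantile_unnormalize(values, q01, q99):
    """Map each value from [-1, 1] back to [q01, q99], channel by channel."""
    return [(v + 1.0) / 2.0 * (hi - lo) + lo for v, lo, hi in zip(values, q01, q99)]


def quantile_normalize(values, q01, q99, eps=1e-8):
    """Inverse mapping: [q01, q99] to [-1, 1]; eps guards a zero-width range."""
    return [2.0 * (v - lo) / max(hi - lo, eps) - 1.0 for v, lo, hi in zip(values, q01, q99)]


q01 = [-1.0, -0.5, 0.0, -1.0, -1.0, -1.0, -1.0]
q99 = [1.0, 0.5, 2.0, 1.0, 1.0, 1.0, 1.0]
low = quantile_unnormalize([-1.0] * 7, q01, q99)
high = quantile_unnormalize([1.0] * 7, q01, q99)
```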
|
||||
|
||||
|
||||
def test_postprocessor_stats_survive_save_load(tmp_path) -> None:
|
||||
# Regression guard for the Hub mechanism: the q01/q99 stats live in the saved post-processor
|
||||
# state and must round-trip through save_pretrained / from_pretrained.
|
||||
cfg = _make_config()
|
||||
q01 = [-0.6, -0.8, -0.9, -0.1, -0.15, -0.25, -1.0]
|
||||
q99 = [0.9, 0.85, 0.9, 0.17, 0.18, 0.34, 1.0]
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats={ACTION: {"q01": q01, "q99": q99}})
|
||||
post.save_pretrained(tmp_path)
|
||||
loaded = PolicyProcessorPipeline.from_pretrained(
|
||||
tmp_path,
|
||||
config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
out = loaded(torch.full((1, 7), -1.0))
|
||||
assert torch.allclose(out, torch.tensor(q01).unsqueeze(0), atol=1e-4)
|
||||
@@ -24,7 +24,6 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -175,53 +174,6 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -668,178 +620,6 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
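The state keys and filenames asserted in these tests follow a visible pattern: a lowercased, underscore-separated pipeline name, then `_step_<index>`, then the registry name when the step is registered. A minimal sketch of that pattern is below. `sanitize_name` and `state_key` are hypothetical helpers, and the exact sanitization rule is an assumption inferred from the expected strings, not a copy of the pipeline code.

```python
import re


def sanitize_name(name: str) -> str:
    """Lowercase the name and replace anything outside [a-z0-9_] with '_'."""
    return re.sub(r"[^a-z0-9_]", "_", name.lower())


def state_key(pipeline_name: str, step_index: int, registry_name=None) -> str:
    """Build the per-step state key; the file name is this key plus '.safetensors'."""
    key = f"{sanitize_name(pipeline_name)}_step_{step_index}"
    if registry_name:
        key += f"_{registry_name}"
    return key


key = state_key("Registry Pipeline", 0, "registered_lazy_tensor_state_step")
state_file = f"{key}.safetensors"
```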
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -1084,8 +1084,8 @@ wheels = [

[[package]]
name = "datasets"
version = "5.0.1.dev0"
source = { git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3#2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
version = "4.8.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "dill" },
    { name = "filelock" },
@@ -1102,6 +1102,10 @@ dependencies = [
    { name = "tqdm" },
    { name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
]

[[package]]
name = "debugpy"
@@ -1168,10 +1172,11 @@ wheels = [

[[package]]
name = "diffusers"
version = "0.35.2"
version = "0.36.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "filelock" },
    { name = "httpx" },
    { name = "huggingface-hub" },
    { name = "importlib-metadata" },
    { name = "numpy" },
@@ -1180,9 +1185,9 @@ dependencies = [
    { name = "requests" },
    { name = "safetensors" },
]
sdist = { url = "https://files.pythonhosted.org/packages/03/68/288ca23c7c05c73e87ffe5efffc282400ac9b017f7a9bb03883f4310ea15/diffusers-0.35.2.tar.gz", hash = "sha256:30ecd552303edfcfe1724573c3918a8462ee3ab4d529bdbd4c0045f763affded", size = 3366711, upload-time = "2025-10-15T04:05:17.213Z" }
sdist = { url = "https://files.pythonhosted.org/packages/88/45/ccb2e2180ddf475a0f931dac6a50346310e4c464ce3cccb8a65d1fc1e16d/diffusers-0.36.0.tar.gz", hash = "sha256:a9cde8721b415bde6a678f2d02abb85396487e1b0e0d2b4abb462d14a9825ab0", size = 3795088, upload-time = "2025-12-08T10:14:34.255Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/2a/2e/38d9824f8c6bb048c5ba21c6d4da54c29c162a46b58b3ef907a360a76d3e/diffusers-0.35.2-py3-none-any.whl", hash = "sha256:d50d5e74fdd6dcf55e5c1d304bc52cc7c2659abd1752740d736d7b54078b4db5", size = 4121649, upload-time = "2025-10-15T04:05:14.391Z" },
    { url = "https://files.pythonhosted.org/packages/35/50/281f92cb1f83854dbd79b6e958b3bc5018607e2542971d41604ba7a14b2f/diffusers-0.36.0-py3-none-any.whl", hash = "sha256:525d42abc74bfc3b2db594999961295c054b48ef40a11724dacf50e6abd1af98", size = 4597884, upload-time = "2025-12-08T10:14:31.979Z" },
]

[[package]]
@@ -1632,6 +1637,18 @@ http = [
    { name = "aiohttp" },
]

[[package]]
name = "ftfy"
version = "6.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "wcwidth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
]

[[package]]
name = "future"
version = "1.0.0"
@@ -1760,7 +1777,7 @@ wheels = [

[[package]]
name = "gym-aloha"
version = "0.1.4"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "dm-control" },
@@ -1768,14 +1785,14 @@ dependencies = [
    { name = "imageio", extra = ["ffmpeg"] },
    { name = "mujoco" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
    { url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
]

[[package]]
name = "gym-hil"
version = "0.1.14"
version = "0.1.13"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "gymnasium" },
@@ -1785,9 +1802,9 @@ dependencies = [
    { name = "pygame" },
    { name = "pynput" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
    { url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
]

[[package]]
@@ -1877,7 +1894,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9

[[package]]
name = "hf-libero"
version = "0.1.4"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "bddl", marker = "sys_platform == 'linux'" },
@@ -1898,10 +1915,7 @@ dependencies = [
    { name = "transformers", marker = "sys_platform == 'linux'" },
    { name = "wandb", marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }

[[package]]
name = "hf-xet"
@@ -2695,6 +2709,7 @@ all = [
    { name = "faker" },
    { name = "fastapi" },
    { name = "feetech-servo-sdk" },
    { name = "ftfy" },
    { name = "grpcio" },
    { name = "grpcio-tools" },
    { name = "gym-aloha" },
@@ -2703,6 +2718,7 @@ all = [
    { name = "hebi-py" },
    { name = "hf-libero", marker = "sys_platform == 'linux'" },
    { name = "hidapi" },
    { name = "imageio", extra = ["ffmpeg"] },
    { name = "ipykernel" },
    { name = "jsonlines" },
    { name = "jupyter" },
@@ -2876,6 +2892,9 @@ hopejr = [
    { name = "pygame" },
    { name = "pyserial" },
]
imageio-dep = [
    { name = "imageio", extra = ["ffmpeg"] },
]
intelrealsense = [
    { name = "pyrealsense2", marker = "sys_platform != 'darwin'" },
    { name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin'" },
@@ -2900,6 +2919,13 @@ libero = [
    { name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
    { name = "transformers" },
]
lingbot-va = [
    { name = "accelerate" },
    { name = "diffusers" },
    { name = "ftfy" },
    { name = "imageio", extra = ["ffmpeg"] },
    { name = "transformers" },
]
matplotlib-dep = [
    { name = "contourpy" },
    { name = "matplotlib" },
@@ -3069,16 +3095,18 @@ xvla = [

[package.metadata]
requires-dist = [
    { name = "accelerate", marker = "extra == 'lingbot-va'", specifier = ">=1.10.0,<2.0.0" },
    { name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
    { name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
    { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
    { name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
    { name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
    { name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3" },
    { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
    { name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
    { name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
    { name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
    { name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.36.0" },
    { name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.37.0" },
    { name = "diffusers", marker = "extra == 'lingbot-va'", specifier = ">=0.36.0,<0.37.0" },
    { name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" },
    { name = "draccus", specifier = "==0.10.0" },
    { name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" },
@@ -3087,16 +3115,18 @@ requires-dist = [
    { name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
    { name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
    { name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
    { name = "ftfy", marker = "extra == 'lingbot-va'", specifier = ">=6.0.0,<7.0.0" },
    { name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
    { name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
    { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
    { name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
    { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
    { name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
    { name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
    { name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
    { name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
    { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
    { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
    { name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
    { name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
    { name = "imageio", extras = ["ffmpeg"], marker = "extra == 'imageio-dep'", specifier = ">=2.34.0,<3.0.0" },
    { name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
    { name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
    { name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
@@ -3127,6 +3157,7 @@ requires-dist = [
    { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
    { name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["eo1"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
    { name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
@@ -3138,10 +3169,12 @@ requires-dist = [
    { name = "lerobot", extras = ["hardware"], marker = "extra == 'core-scripts'" },
    { name = "lerobot", extras = ["hilserl"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["hopejr"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["imageio-dep"], marker = "extra == 'lingbot-va'" },
    { name = "lerobot", extras = ["intelrealsense"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["kinematics"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["lekiwi"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["libero"], marker = "sys_platform == 'linux' and extra == 'all'" },
    { name = "lerobot", extras = ["lingbot-va"], marker = "extra == 'all'" },
    { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'async'" },
    { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
    { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
@@ -3198,6 +3231,7 @@ requires-dist = [
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'lingbot-va'" },
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" },
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" },
    { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" },
@@ -3275,7 +3309,7 @@ requires-dist = [
    { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
    { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "imageio-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "lingbot-va", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]

[[package]]
name = "librt"