Compare commits

..

16 Commits

Author SHA1 Message Date
Gangwei XU 132ea975f0 fix(lingbot-va): align RoboTwin evaluation (#3784)
Thank you for the RoboTwin fix, and alignment!
2026-06-15 09:52:48 +02:00
Pepijn 961e0d9bcd docs(lingbot_va): condense processor normalization comments
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 12:04:48 +02:00
Pepijn 6496728025 docs(lingbot_va): point checkpoint paths at the lerobot org
The LeRobot-format checkpoints moved from pepijn223/* to lerobot/* (libero_long,
robotwin, base). Update the eval/train --policy.path examples accordingly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:58:31 +02:00
Pepijn 3b37bd0ca6 refactor(lingbot_va): use built-in UnnormalizerProcessorStep for actions
Replace the bespoke LingBotVAActionUnnormalizeStep with the standard
UnnormalizerProcessorStep in QUANTILES mode, which computes the identical
(action + 1) / 2 * (q99 - q01) + q01 mapping. The per-channel q01/q99 are stored
as the step's saved state (a safetensors file) and restored on load; a fresh build
has no action stats so the step is an identity passthrough.

The 3 Hub checkpoints (lerobot/lingbot_va_{libero_long,robotwin,base}) have been
re-uploaded with the new post-processor (policy_postprocessor.json +
*_unnormalizer_processor.safetensors); reloading from the Hub round-trips q01/q99.

- processor_lingbot_va.py: drop the custom step + registry; build the post-processor
  with UnnormalizerProcessorStep (explicit ACTION->QUANTILES norm_map so the
  preprocessor / training path is unchanged).
- tests: assert the built-in step is used, identity-when-no-stats, correct quantile
  unnormalization, and a save_pretrained/from_pretrained stats round-trip.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:57:31 +02:00
Pepijn 8e692e365c docs(lingbot_va): trim provenance comments; default wan path to base repo
- configuration_lingbot_va.py: drop the "──" decorations and the
  "(from transformer/config.json)" note; default wan_pretrained_path to
  robbyant/lingbot-va-base (has the frozen vae/text_encoder/tokenizer subfolders).
- modeling_lingbot_va.py: remove the vendored-code banner and the
  "(upstream wan_va/...)" section-header provenance/dash decorations; condense the
  transformer-dtype comment to one line.

No code changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:47:45 +02:00
Pepijn f617b2c2bf docs(lingbot_va): trim verbose comments
- configuration_lingbot_va.py: condense multi-line field comments to one-liners
  (keep the ── section headers).
- processor_lingbot_va.py: shorten the action-quantile explanation block.
- modeling_lingbot_va.py: drop the bare "# ----" separator rules, keeping the
  one-line section headers.

No code changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:31:05 +02:00
Pepijn c6a51b9b60 refactor(lingbot_va): drop hardcoded action quantiles; source from checkpoint
The LIBERO/RoboTwin action (un)normalization quantiles were hardcoded as module
constants in processor_lingbot_va.py. They are already serialized into each
checkpoint's policy_postprocessor.json (via LingBotVAActionUnnormalizeStep.get_config)
and restored on load by PolicyProcessorPipeline.from_pretrained, so the constants are
dead at eval/load time for the released checkpoints (verified: libero_long/robotwin/base
all carry their quantiles on the Hub).

- Remove LIBERO_ACTION_Q01/Q99, ROBOTWIN_ACTION_Q01/Q99 and _default_action_quantiles.
- make_lingbot_va_pre_post_processors now defaults a fresh (unconverted) build to a
  neutral [-1, 1] mapping (identity rescale); real per-benchmark stats come from the
  saved checkpoint (or postprocessor_overrides), analogous to dataset-stats normalization.
- Update the config doc comment to point at the checkpoint as the source of truth.
- Tests: replace the LIBERO-default assertion with a neutral-default check, and add a
  save_pretrained/from_pretrained round-trip guard for the quantile serialization.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-08 11:22:42 +02:00
Pepijn ab49c71c22 Update pyproject.toml
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-08 11:17:16 +02:00
Pepijn 459efef8a0 Update pyproject.toml
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-08 11:16:08 +02:00
Pepijn 5568ce7af1 Update lingbot_va.mdx
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2026-06-08 10:47:34 +02:00
Pepijn be0320a420 Merge branch 'main' into worktree-lingbot-va-port 2026-06-08 10:38:35 +02:00
pepijn223 5222f3a4a7 docs(lingbot_va): document EEF action-channel schema + camera order
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 16:11:35 +02:00
pepijn223 f9d12db9cf fix(lingbot_va): CI quality gate + fast-test collection
- Add tests/policies/lingbot_va/__init__.py so the test files don't clash by basename
  with tests/policies/vla_jepa/* under pytest's default import mode (fast-test collection error).
- Fix vendored typos flagged by the typos hook (pach_scale->patch_scale, total_tolen->
  total_token_len, stablized->stabilized) and a mypy union-attr in RoboTwinEnv._read_eef_pose.
- Apply Prettier formatting to docs/source/lingbot_va.mdx.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 15:46:37 +02:00
pepijn223 71aacda05e feat(lingbot_va): implement training / fine-tuning (flow-matching loss)
- Implement LingBotVAPolicy.forward(): dual-stream flow-matching training loss
  (latent + action, timestep-weighted, action-masked) ported from upstream train.py;
  VAE-encodes camera clips, UMT5-encodes the task, noises both streams, runs the
  block-causal flex-attention training pass (forward_train).
- training_loss_from_streams() core + _build_training_streams() data prep (action
  scatter into the 30-d space, multi-frame VAE encode incl. robotwin_tshape).
- get_optim_params returns only trainable transformer params (LoRA/PEFT friendly);
  VAE/UMT5 stay frozen. Training needs attn_mode='flex'.
- Add a tiny-config single-training-step test (forward->loss->backward->AdamW) and a
  Training/fine-tuning section in the docs.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 15:38:41 +02:00
pepijn223 e3deff00ad feat(lingbot_va): RoboTwin eef-pose eval, single-file model, Hub checkpoints
Make the LingBot-VA port runnable on both LIBERO and RoboTwin and clean up the
package to LeRobot conventions.

- Consolidate all vendored Wan2.2 model code (transformer, attention, VAE helpers,
  flow-matching scheduler, grid utils, flex-attention) into a single
  modeling_lingbot_va.py; remove the separate wan_*/schedulers modules.
- Move the fixed action (un)normalization quantiles out of the config and into the
  post-processor (LIBERO 7-DoF + RoboTwin 16-d eef); remove the conversion script in
  favour of ready-to-use LeRobot-format checkpoints on the Hub.
- Fixes found via on-sim validation: undo LIBERO's 180-degree image flip
  (image_hflip), encode obs as a multi-frame streaming-VAE clip, reset the streaming
  VAE cache between episodes, run the transformer in config.dtype, lazy-load frozen
  VAE/UMT5 by subfolder with the text encoder on CPU.
- RoboTwin: add an end-effector-pose action mode to RoboTwinEnv (16-d per-arm
  xyz+quat+gripper deltas composed onto the initial eef pose, executed via CuRobo IK)
  and the robotwin_tshape latent layout (full-res head + half-res wrists via a second
  streaming VAE) with the upstream RoboTwin action quantiles + camera mapping.
- Predicted-video saving works for both benchmarks; docs + tests updated.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-06 15:20:51 +02:00
Pepijn 4dfa8cea65 feat(policies): add LingBot-VA autoregressive video-action world model
Port the LingBot-VA policy (Wan2.2 dual-stream video+action world model) into
LeRobot, following the EO-1 / VLA-JEPA conventions. Covers inference, checkpoint
conversion, and predicted-video saving (training is deferred to a follow-up PR).

- Vendored Wan transformer/attention/flex/VAE/scheduler modules (key names preserved
  for near-identity conversion); torch SDPA default, flashattn/flex lazy-guarded.
- LingBotVAConfig (registered "lingbot_va") + processor with fixed-quantile action
  unnormalization; full dual-stream sampling loop with CFG, two flow-matching
  schedulers and KV cache, mapped onto select_action with observed-keyframe feedback.
- convert_lingbot_va_checkpoints.py (libero/robotwin variants): bundles the ~5B
  transformer, lazy-pulls the frozen VAE+UMT5 from the source repo.
- Predicted-video plumbing in lerobot_eval (predicted_frames_callback; opt-in via
  --policy.save_predicted_video) and ConstantWithWarmupSchedulerConfig.
- pyproject: widen diffusers-dep to <0.37, add lingbot_va + imageio-dep extras,
  add lingbot_va and (missing) eo1 to `all`.
- Factory + policies/__init__ wiring, docs page + toctree, and tests.

Note: the LIBERO success-rate correctness gate must be validated on a CUDA GPU
with the converted checkpoint.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-05 16:28:19 +02:00
43 changed files with 3738 additions and 3857 deletions
+4
View File
@@ -22,6 +22,10 @@ outputs
rl
media
# Local virtualenvs (the image provides its own)
.venv
venv
# Logging
logs
+2
View File
@@ -67,6 +67,8 @@
title: VLA-JEPA
- local: eo1
title: EO-1
- local: lingbot_va
title: LingBot-VA
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
+187
View File
@@ -0,0 +1,187 @@
# LingBot-VA
LingBot-VA is an **autoregressive video-action world-model policy** built on the **Wan2.2**
video-diffusion stack. It interleaves, in one autoregressive sequence, the prediction of
future **video latents** and **robot actions** ("VA" = Video-Action). The LeRobot
integration wires LingBot-VA into the standard training, evaluation and processor
interfaces.
## Model Overview
LingBot-VA is a **dual-stream "mixture-of-transformers"**: a video/latent stream
(`patch_embedding_mlp → blocks → proj_out`) and an action stream
(`action_embedder → blocks → action_proj_out`) share the same 30 transformer blocks and
text conditioning.
| Component | Class | Role |
| ------------------------ | ----------------------- | ----------------------------------------------------------- |
| DiT backbone (trainable) | `WanTransformer3DModel` | ~5B-param dual-stream transformer. |
| VAE (frozen) | `AutoencoderKLWan` | Wan2.2 VAE, `z_dim=48`. Lazy-pulled from the source repo. |
| Text encoder (frozen) | `UMT5EncoderModel` | UMT5-XXL, `d_model=4096`. Lazy-pulled from the source repo. |
At inference the policy runs an autoregressive loop per chunk: it denoises the video-latent
stream (CFG, ~20 steps) and the action stream (~50 steps) with two independent
flow-matching schedulers, maintaining a KV cache across chunks. Real observed keyframes are
fed back into the KV cache as the chunk is executed (closed-loop world modeling).
### What the LeRobot Integration Covers
- Standard `policy.type=lingbot_va` configuration through LeRobot.
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
- Autoregressive dual-stream inference behind the standard `select_action` interface
(single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
- Evaluation with `lerobot-eval` on LIBERO and RoboTwin.
- Training / fine-tuning via the dual-stream flow-matching loss (`policy.forward`), see below.
## Installation
1. Install LeRobot by following the [Installation Guide](./installation).
2. Install the LingBot-VA extra:
```bash
pip install -e ".[lingbot_va]"
```
## Checkpoints
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
| Variant | LeRobot checkpoint |
| ---------------------- | -------------------------------- |
| LIBERO-Long post-train | `lerobot/lingbot_va_libero_long` |
| RoboTwin post-train | `lerobot/lingbot_va_robotwin` |
| Pretrained base | `lerobot/lingbot_va_base` |
Only the trainable ~5B transformer is stored in the LeRobot
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are pulled from
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
transformer + VAE fit on a single 2432 GB GPU.
## Evaluation (LIBERO)
```bash
lerobot-eval \
--policy.path=lerobot/lingbot_va_libero_long \
--policy.device=cuda \
--env.type=libero --env.task=libero_10 \
--env.observation_height=128 --env.observation_width=128 \
--eval.n_episodes=50 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_libero
```
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
single-environment eval; use `--eval.batch_size=1`.
## Evaluation (RoboTwin)
RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack. You can use the benchmark Docker image
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
executed via CuRobo IK.
```bash
lerobot-eval \
--policy.path=lerobot/lingbot_va_robotwin \
--policy.device=cuda \
--env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
--eval.n_episodes=10 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_robotwin
```
### Saving predicted (imagined) videos
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
latents and write `pred_episode_*.mp4` next to the env-rendered `eval_episode_*.mp4` videos.
The same flag works for the periodic eval during `lerobot-train`.
## Training / fine-tuning
`LingBotVAPolicy.forward(batch)` implements the dual-stream **flow-matching** loss
(`latent_loss + action_loss`, timestep-weighted, action-masked) from the paper: it VAE-encodes
the camera clips into video latents, UMT5-encodes the task, noises both streams, runs the
transformer's block-causal training pass and returns `(loss, metrics)`. Optimizer preset is AdamW
with a linear-warmup-then-constant schedule (matching upstream).
Requirements:
- The block-causal masks use PyTorch **flex-attention**, so build the policy with
`--policy.attn_mode=flex` for training (the default `torch` SDPA is inference-only).
- The full 5B DiT does not fit a single 2432 GB GPU under AdamW; fine-tune with **LoRA**
(`--policy.use_peft=true`) and/or optimizer offload. `get_optim_params` returns only the
trainable (e.g. adapter) parameters; the VAE + UMT5 text encoder stay frozen.
```bash
lerobot-train \
--policy.path=lerobot/lingbot_va_libero_long --policy.attn_mode=flex \
--policy.use_peft=true \
--dataset.repo_id=<your LeRobot-format dataset> \
--batch_size=1 --steps=... --output_dir=outputs/train/lingbot_va
```
The dataset must provide camera clips (a temporal window per camera, VAE-encoded to
`frame_chunk_size` latent frames) and `frame_chunk_size * action_per_frame` action steps per item.
## Data format (action channels & camera order)
LingBot-VA is an **end-effector (Cartesian) pose** policy, it predicts EEF poses + gripper, not
joint positions. Actions live in a fixed multi-embodiment **30-dim** layout; map your robot's
action dimensions into these channels and pad the rest with `0` (`used_action_channel_ids` selects
the channels a given checkpoint actually uses):
| channels | meaning |
| -------- | ----------------------------------------------------- |
| 06 | Left-arm end-effector pose |
| 713 | Right-arm end-effector pose |
| 1420 | Left-arm joints (unused by the released checkpoints) |
| 2127 | Right-arm joints (unused by the released checkpoints) |
| 28 | Left gripper |
| 29 | Right gripper |
- **LIBERO** uses channels `06`: a 6-DoF EEF delta (xyz + rotation) + gripper (single arm).
- **RoboTwin** uses channels `[06, 28, 713, 29]`: left EEF (xyz + quaternion) + left gripper +
right EEF + right gripper (16 dims). The env converts these poses to joint trajectories via
CuRobo IK — joints are never predicted.
Joint-space datasets (or a different EEF convention) must be remapped into this schema before
fine-tuning these checkpoints.
**Camera order is fixed and order-sensitive**, per-camera latents are concatenated spatially in
`obs_cam_keys` order, so the physical camera→slot mapping must match training:
| benchmark | `obs_cam_keys` (in order) | `camera_layout` |
| --------- | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
| LIBERO | `observation.images.image` (agentview / 3rd-person), `observation.images.image2` (eye-in-hand wrist) | `width_concat` (latents concatenated on width) |
| RoboTwin | `observation.images.head_camera`, `observation.images.left_camera`, `observation.images.right_camera` | `robotwin_tshape` (full-res head below, two half-res wrists on top) |
The first camera is the exterior/head view and the rest are wrist views.
## Inference Hyperparameters (LIBERO)
| Key | Value |
| -------------------------------------- | --------------------------------------------------------------------------------- |
| height × width | 128 × 128 |
| cameras | `observation.images.image` (agentview), `observation.images.image2` (eye-in-hand) |
| action channels used | 06 (7-DoF arm + gripper) |
| action_per_frame / frame_chunk_size | 4 / 4 |
| attn_window | 30 |
| video / action denoising steps | 20 / 50 |
| guidance_scale / action_guidance_scale | 5 / 1 |
| snr_shift / action_snr_shift | 5.0 / 0.05 |
These are the defaults of `LingBotVAConfig`; override any of them via `--policy.<name>=...`.
## Notes
- **Attention backend:** inference uses the `torch` SDPA backend (always available). The
`flashattn` and `flex` backends are optional; `flex` is only needed for training.
- **Model size:** the DiT is ~5B params and the frozen VAE+UMT5 add ~20 GB; inference needs
roughly 1824 GB of VRAM.
## License
LingBot-VA is released under Apache-2.0. See the
[upstream repository](https://github.com/Robbyant/lingbot-va).
-547
View File
@@ -1,547 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Single-image dataloading benchmark across the LeRobot loaders, MADE TO RUN ON A COMPUTE CLUSTER (SLURM).
This one file is both the orchestrator and the worker:
* Run it with no ``--scenario`` (from a login node) and it submits a SERIAL sbatch chain of all
scenarios below (no two network-bound jobs overlap, so CDN numbers stay clean).
* Run it with ``--scenario <name>`` and it executes that single benchmark (this is what each sbatch
job calls). The 2-node scenario is launched with ``srun`` and reads ``RANK``/``WORLD_SIZE`` so the
streaming dataset splits shards per node.
Scenarios (all single-frame / non-SARM):
1. ``mmap_local`` map-style LeRobotDataset over a LOCAL copy (``--local_root``, no network).
2. ``mmap_local_maxworkers`` same, but workers scaled to saturate the node's cores (decode-bound).
3. ``stream_hub`` StreamingLeRobotDataset from the Hub (allenai/MolmoAct2-BimanualYAM-Dataset).
4. ``stream_bucket`` StreamingLeRobotDataset from a warmed storage bucket (1 node).
5. ``stream_bucket_2node`` same warmed bucket, 2 nodes (split_dataset_by_node, per-rank results).
Reported per run: peak process-tree RSS (max memory), parallel throughput (samples/s, where a sample
is one timestep, plus decoded_frames/s = samples/s x num_cameras),
single-process throughput, shuffle randomness fraction (distinct episodes per batch / batch size),
fetch vs decode split (% of single-process per-sample time), first-batch latency, and p50/p95/p99
sample latency. Results are written as JSON + CSV under ``--out_dir``.
Submit the whole chain (from a login node, inside the repo). Point the scheduler env vars at your own
cluster's account/partition/qos, and ``--local_root`` at a local copy of the map-style dataset:
ACCOUNT=<account> PARTITION=<partition> QOS=<qos> \\
python examples/scaling/benchmark_dataloading.py --local_root /path/to/local/dataset
"""
import argparse
import csv
import json
import os
import random
import statistics
import subprocess
import sys
import threading
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.datasets.partition import group_episodes_by_files, partition_episodes
ROBOCASA_REPO = "pepijn223/robocasa_pretrain_human300_v4"
MOLMO_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
MOLMO_BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
# MolmoAct2 is published without a codebase-version git tag, so the version-safe loader would refuse
# it; "main" pins the branch directly and skips that check.
MOLMO_REVISION = "main"
# Per-scenario sbatch shape. mem is generous for the streaming legs (32k-episode, 3-camera, 2.35 TB
# dataset keeps many AV1 decoders open); the local map-style leg is light. Optional ``num_workers`` /
# ``cpus`` override the CLI defaults for that leg.
# ``mmap_local_maxworkers``: map-style decode is CPU-bound and each worker decodes its cameras on
# parallel threads, so the saturation point is ~num_cpus / num_cameras workers (~90 concurrent decode
# threads). The 96-core H100 nodes here schedule at most 92 cpus/task, so we take 92 cpus / 30 workers.
SCENARIOS = {
"mmap_local": {"kind": "map", "nodes": 1, "mem": "64G", "time": "01:00:00"},
"mmap_local_maxworkers": {
"kind": "map",
"nodes": 1,
"mem": "128G",
"time": "01:00:00",
"num_workers": 30,
"cpus": 92,
},
"stream_hub": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
"stream_bucket": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
"stream_bucket_2node": {"kind": "stream", "nodes": 2, "mem": "250G", "time": "03:00:00"},
}
def _tree_rss_bytes() -> int:
"""Sum RSS of this process and all descendants via /proc (DataLoader workers are separate procs)."""
try:
children: dict[int, list[int]] = {}
for entry in os.listdir("/proc"):
if not entry.isdigit():
continue
try:
with open(f"/proc/{entry}/stat") as f:
ppid = int(f.read().split(") ", 1)[1].split()[1])
children.setdefault(ppid, []).append(int(entry))
except (OSError, ValueError, IndexError):
pass
total, stack = 0, [os.getpid()]
while stack:
cur = stack.pop()
try:
with open(f"/proc/{cur}/statm") as f:
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
except (OSError, ValueError, IndexError):
pass
stack.extend(children.get(cur, []))
return total
except OSError:
return 0
class PeakRSSSampler:
"""Background thread tracking peak process-tree RSS for the duration of the ``with`` block."""
def __init__(self, interval_s: float = 0.5):
self.interval_s = interval_s
self.peak_bytes = 0
self._stop = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
def _run(self) -> None:
while not self._stop.is_set():
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
self._stop.wait(self.interval_s)
def __enter__(self) -> "PeakRSSSampler":
self._thread.start()
return self
def __exit__(self, *exc) -> None:
self._stop.set()
self._thread.join(timeout=2)
def percentile(values: list[float], pct: float) -> float:
if not values:
return float("nan")
ordered = sorted(values)
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
return ordered[k]
class _TimedStreaming(StreamingLeRobotDataset):
"""StreamingLeRobotDataset that times the fetch stage (parquet/network row) separately from the
decode stage (video decode + torch conversion in ``_finalize_sample``), so a single-process pass
can attribute per-sample cost to fetch vs decode. Timing lives here in the benchmark, not in the
library, to keep the dataset itself instrumentation-free."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fetch_s = 0.0
self.decode_s = 0.0
def __iter__(self):
self._in_flight_epoch = self._epoch
self._pipeline.set_epoch(self._in_flight_epoch)
self._epoch += 1
self.video_decoder_cache = self._make_video_decoder_cache()
iterator = iter(self._pipeline)
while True:
t0 = time.perf_counter()
try:
row = next(iterator)
except StopIteration:
return
t1 = time.perf_counter()
sample = self._finalize_sample(row)
t2 = time.perf_counter()
self.fetch_s += t1 - t0
self.decode_s += t2 - t1
yield sample
def select_node_episodes(
meta: LeRobotDatasetMetadata, num_partitions: int, index: int, cap: int
) -> list[int]:
"""This node's episode share, mirroring lerobot_train ``--data_partition=node``: group episodes by
shared video files, LPT-balance the groups by frame count, take this node's bin (capped)."""
episodes = list(range(meta.total_episodes))
from_idx = meta.episodes["dataset_from_index"]
to_idx = meta.episodes["dataset_to_index"]
lengths = [int(to_idx[ep] - from_idx[ep]) for ep in episodes]
if meta.video_keys:
file_columns = {
key: (meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])
for key in meta.video_keys
}
else:
file_columns = {"data": (meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])}
episode_file_ids = [
[(key, chunks[ep], files[ep]) for key, (chunks, files) in file_columns.items()] for ep in episodes
]
groups = group_episodes_by_files(episode_file_ids)
if len(groups) < num_partitions:
groups = [[i] for i in range(len(episodes))]
group_lengths = [sum(lengths[i] for i in g) for g in groups]
bins = partition_episodes(group_lengths, num_partitions)
chosen = sorted(episodes[i] for g in bins[index] for i in groups[g])
return chosen[:cap] if cap and len(chosen) > cap else chosen
def build_dataset(scenario: str, args: argparse.Namespace):
"""Return (dataset, meta, is_map_style, info) for the scenario; single-frame (no delta windows)."""
if scenario.startswith("mmap_local"):
if not args.local_root:
raise SystemExit("mmap_local needs --local_root pointing at a local LeRobotDataset copy.")
meta = LeRobotDatasetMetadata(ROBOCASA_REPO, root=args.local_root)
episodes = select_node_episodes(meta, args.num_partitions, args.partition_index, args.max_episodes)
dataset = LeRobotDataset(ROBOCASA_REPO, root=args.local_root, episodes=episodes, tolerance_s=1e-3)
return dataset, meta, True, {"loaded_episodes": len(episodes)}
data_files_root = MOLMO_BUCKET if scenario.startswith("stream_bucket") else None
meta = LeRobotDatasetMetadata(MOLMO_REPO, revision=MOLMO_REVISION)
dataset = _TimedStreaming(
MOLMO_REPO,
revision=MOLMO_REVISION,
data_files_root=data_files_root,
episode_pool_size=args.episode_pool_size,
max_buffer_input_shards=args.max_buffer_input_shards,
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
# Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
# dataset may be collapsed); reshard() still yields per-episode shards where it holds.
validate_row_groups=False,
)
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
def _split(fetch_s: float, decode_s: float, getitem_s: float, n_probe: int) -> dict:
stage = fetch_s + decode_s
return {
"single_proc_samples_per_s": round(n_probe / getitem_s, 2) if getitem_s else None,
"fetch_pct": round(100 * fetch_s / stage, 1) if stage else None,
"decode_pct": round(100 * decode_s / stage, 1) if stage else None,
}
def measure_fetch_decode_stream(dataset: _TimedStreaming, n_probe: int, warmup: int) -> dict:
"""Single-process pass attributing per-sample time to fetch (parquet/network row) vs decode (video)."""
it = iter(dataset)
for _ in range(warmup): # exclude the cold shuffle-buffer fill from the ratio
next(it)
dataset.fetch_s = dataset.decode_s = 0.0
t0 = time.perf_counter()
for _ in range(n_probe):
next(it)
return _split(dataset.fetch_s, dataset.decode_s, time.perf_counter() - t0, n_probe)
def measure_fetch_decode_map(dataset: LeRobotDataset, n_probe: int, warmup: int) -> dict:
"""Same split for the map-style loader: fetch = raw tabular row (``get_raw_item``), decode = the rest
of ``__getitem__`` (video decode + transforms). Local reads make fetch tiny and decode dominant.
Random frames are resampled past any that torchcodec fails to decode, so a single flaky frame can't
abort the whole benchmark (the parallel DataLoader pass draws its own fresh random frames)."""
rng = random.Random(0)
n = len(dataset)
fetch_s = getitem_s = 0.0
warmed = measured = skipped = attempts = 0
while measured < n_probe and attempts < (warmup + n_probe) * 10:
attempts += 1
i = rng.randrange(n)
try:
t0 = time.perf_counter()
dataset.get_raw_item(i)
t1 = time.perf_counter()
dataset[i]
t2 = time.perf_counter()
except Exception:
skipped += 1
continue
if warmed < warmup:
warmed += 1
continue
fetch_s += t1 - t0
getitem_s += t2 - t1
measured += 1
if skipped:
print(f"map fetch/decode probe skipped {skipped} undecodable frame(s)", flush=True)
return _split(fetch_s, max(0.0, getitem_s - fetch_s), getitem_s, measured)
def run_scenario(scenario: str, args: argparse.Namespace) -> None:
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
device = torch.device(args.device)
dataset, meta, is_map_style, info = build_dataset(scenario, args)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=is_map_style, # map-style: global random shuffle; streaming: shuffled inside the dataset
pin_memory=device.type == "cuda",
drop_last=True,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
persistent_workers=args.num_workers > 0,
)
sample_latencies_ms: list[float] = []
episodes_per_batch: list[int] = []
samples = 0
first_batch_latency_s = None
steady_start = None
t_start = time.perf_counter()
t_prev = t_start
with PeakRSSSampler() as rss:
for i, batch in enumerate(loader):
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
if i == args.warmup_batches:
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
samples += args.batch_size
ep = batch.get("episode_index")
if torch.is_tensor(ep):
episodes_per_batch.append(int(torch.unique(ep).numel()))
t_prev = now
# Measure throughput over a fixed wall-clock window (after warmup) so every scenario is
# compared over the same duration regardless of its speed; num_batches is only a safety cap.
if steady_start is not None and (now - steady_start) >= args.duration_s:
break
if i + 1 >= args.num_batches:
break
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
now = time.perf_counter()
elapsed = now - t_start
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
if samples == 0:
raise SystemExit(
f"FAILED: 0 samples in {args.duration_s}s for scenario={scenario} "
"(inspect worker logs; try --num_workers 0 to surface the exception)."
)
# Single-process fetch/decode split + single-proc throughput. Run AFTER the DataLoader pass: this
# decodes video in the main process, which must stay decode-clean until the workers have forked
# (decoding before fork corrupts the workers' torchcodec state).
del loader
if is_map_style:
fetch_decode = measure_fetch_decode_map(dataset, args.probe_samples, args.probe_warmup)
else:
fetch_decode = measure_fetch_decode_stream(dataset, args.probe_samples, args.probe_warmup)
image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
num_cameras = len(meta.video_keys)
results = {
"scenario": scenario,
"rank": rank,
"world_size": world_size,
"loader": "map_style" if is_map_style else "streaming",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"episode_pool_size": None if is_map_style else args.episode_pool_size,
"max_buffer_input_shards": None
if is_map_style
else (args.max_buffer_input_shards or args.episode_pool_size),
**info,
"num_cameras": num_cameras,
"image_shape": image_shape,
"fps": meta.fps,
"peak_rss_gb": peak_rss_gb,
"samples_measured": samples,
"steady_window_s": round(steady_elapsed_s, 2),
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 3),
# Parallel throughput over the steady window (excludes warmup + the prefetch queue it filled).
# A sample is one timestep (one dataset item); it decodes num_cameras video frames.
"samples_per_s": round(samples / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
"decoded_frames_per_s": round(samples / steady_elapsed_s * num_cameras, 2)
if steady_elapsed_s
else 0.0,
**fetch_decode,
# Distinct episodes per batch / batch size: ~1.0 ≈ map-style uniform, low ≈ correlated samples.
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
if episodes_per_batch
else None,
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
if sample_latencies_ms
else None,
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
"total_time_s": round(elapsed, 2),
}
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
tag = f"{scenario}_bs{args.batch_size}_w{args.num_workers}_r{rank}of{world_size}"
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
flat = {k: (json.dumps(v) if isinstance(v, (dict, list)) else v) for k, v in results.items()}
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=list(flat))
writer.writeheader()
writer.writerow(flat)
print(json.dumps(results, indent=2), flush=True)
print(f"Wrote {out_dir / tag}.json and .csv", flush=True)
def submit_chain(args: argparse.Namespace) -> None:
"""Submit every scenario as a serial sbatch chain (one network-bound job at a time).
Bodies are passed to ``sbatch --wrap`` as a single argv (no outer shell), so ``$SLURM_PROCID`` /
``$SLURM_NTASKS`` stay literal and expand at job runtime, not at submit time.
"""
this_file = Path(__file__).resolve()
repo_dir = str(this_file.parents[2]) # <repo>/examples/scaling/<this file>
logs = Path(repo_dir) / "logs"
logs.mkdir(exist_ok=True)
run = f"conda run --no-capture-output -n {args.conda_env} python"
common = (
f"--batch_size {args.batch_size} "
f"--prefetch_factor {args.prefetch_factor} --episode_pool_size {args.episode_pool_size} "
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
)
if args.max_buffer_input_shards is not None:
common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
if args.local_root:
common += f" --local_root {args.local_root}"
env_prefix = "export TOKENIZERS_PARALLELISM=false"
sched = []
for opt, env in (("--account", "ACCOUNT"), ("--partition", "PARTITION"), ("--qos", "QOS")):
if os.environ.get(env):
sched.append(f"{opt}={os.environ[env]}")
selected = args.scenarios.split(",") if args.scenarios else list(SCENARIOS)
prev = ""
for scenario in selected:
cfg = SCENARIOS[scenario]
nw = cfg.get("num_workers", args.num_workers)
cpus = cfg.get("cpus", nw + 4)
worker = f"{run} {this_file} --scenario {scenario} --num_workers {nw} {common}"
if cfg["nodes"] > 1:
# One task per node; each exports RANK/WORLD_SIZE so the stream splits shards per node.
inner = f"export RANK=$SLURM_PROCID WORLD_SIZE=$SLURM_NTASKS && cd {repo_dir} && {env_prefix} && {worker}"
body = f"srun --export=ALL bash -c '{inner}'"
node_flags = [f"--nodes={cfg['nodes']}", "--ntasks-per-node=1", "--gpus-per-node=1"]
else:
body = f"cd {repo_dir} && {env_prefix} && {worker}"
node_flags = ["--nodes=1", "--ntasks=1", "--gpus=1"]
cmd = [
"sbatch",
"--parsable",
f"--job-name=dlbench_{scenario}",
*node_flags,
f"--cpus-per-task={cpus}",
f"--mem={cfg['mem']}",
f"--time={cfg['time']}",
f"--output={logs}/%x-%j.out",
*sched,
]
if prev:
cmd.append(f"--dependency=afterany:{prev}")
cmd += ["--wrap", body]
jid = subprocess.check_output(cmd, text=True).strip().split(";")[0]
print(f"submitted {jid} dlbench_{scenario}{f' (after {prev})' if prev else ''}", flush=True)
prev = jid
print(f"\nSubmitted {len(selected)} jobs as a serial chain. Results: {args.out_dir}/*.json", flush=True)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument(
"--scenario",
choices=list(SCENARIOS),
default=None,
help="Run ONE scenario (worker mode). Omit to submit the whole chain (orchestrator mode).",
)
p.add_argument(
"--scenarios",
type=str,
default=None,
help="Orchestrator only: comma-separated subset of scenarios to submit (default: all).",
)
p.add_argument("--local_root", type=str, default=None, help="Local LeRobotDataset copy for mmap_local.")
p.add_argument(
"--num_partitions", type=int, default=8, help="Node count for mmap_local episode partition."
)
p.add_argument("--partition_index", type=int, default=0)
p.add_argument(
"--max_episodes", type=int, default=512, help="Cap mmap_local episodes to the local share."
)
p.add_argument("--batch_size", type=int, default=64)
p.add_argument("--num_workers", type=int, default=8)
p.add_argument("--prefetch_factor", type=int, default=2)
p.add_argument(
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
)
p.add_argument(
"--max_buffer_input_shards",
type=int,
default=None,
help="Concurrently-live random episodes feeding the pool after reshard() "
"(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.",
)
p.add_argument(
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
)
p.add_argument(
"--duration_s", type=float, default=60.0, help="Steady-state measurement window (seconds)."
)
p.add_argument(
"--num_batches", type=int, default=1_000_000, help="Safety cap; duration_s governs the window."
)
p.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state throughput.")
p.add_argument(
"--probe_samples", type=int, default=100, help="Single-process samples for fetch/decode split."
)
p.add_argument(
"--probe_warmup", type=int, default=10, help="Samples skipped before the fetch/decode probe."
)
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
p.add_argument("--conda_env", type=str, default="lerobot", help="Conda env the chained jobs run in.")
p.add_argument("--out_dir", type=str, default="benchmarks/streaming/results_dataloading")
return p.parse_args()
def main() -> None:
args = parse_args()
if args.scenario is None:
if torch.cuda.is_available():
print(
"NOTE: no --scenario given, submitting the SLURM chain. This benchmark is meant to run on a "
"compute cluster; run from a login node with ACCOUNT/PARTITION/QOS set.",
file=sys.stderr,
)
submit_chain(args)
else:
run_scenario(args.scenario, args)
if __name__ == "__main__":
main()
+11 -15
View File
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
"datasets>=4.7.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
@@ -146,7 +146,8 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.37.0"]
imageio-dep = ["imageio[ffmpeg]>=2.34.0,<3.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
@@ -216,8 +217,9 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -231,9 +233,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
@@ -284,6 +286,7 @@ all = [
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[lingbot_va]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",
@@ -333,16 +336,6 @@ explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships
# them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the
# shared httpx.Client on every ConnectError, which races with concurrent streaming requests
# ("Cannot send a request, as the client has been closed"); we patch that out locally in
# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch.
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" }
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
@@ -385,6 +378,9 @@ ignore = [
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
# Vendored Wan2.2 / LingBot-VA model code uses tensor-dimension names (B, F, H, W) and `F` for
# torch.nn.functional.
"src/lerobot/policies/lingbot_va/**" = ["N803", "N806", "N812", "SIM102"]
[tool.ruff.lint.isort]
combine-as-imports = true
-51
View File
@@ -1,51 +0,0 @@
#!/usr/bin/env python
"""Build mmap-able byte-index sidecars for LeRobot streaming datasets."""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from lerobot.datasets.byte_index_builder import (
build_byte_index_tables,
load_existing_file_ids,
write_byte_index,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main() -> None:
parser = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.")
parser.add_argument("--repo-id", required=True)
parser.add_argument("--revision", default=None)
parser.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/")
parser.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory")
parser.add_argument("--workers", type=int, default=8)
parser.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)")
parser.add_argument("--no-keyframes", action="store_true")
args = parser.parse_args()
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
output = args.output
existing = load_existing_file_ids(output)
if existing:
logger.info("resuming: %s files already indexed", len(existing))
files_tbl, episodes_tbl, keyframes_tbl = build_byte_index_tables(
meta,
args.data_root,
include_keyframes=not args.no_keyframes,
workers=args.workers,
existing_files=existing,
max_episodes=args.max_episodes,
)
write_byte_index(output, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True)
logger.info("wrote byte index to %s", output)
if __name__ == "__main__":
main()
-4
View File
@@ -39,10 +39,6 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
# the pool holds tabular rows only. Ignored when streaming is False.
streaming_episode_pool_size: int = 1024
def __post_init__(self) -> None:
if self.episodes is not None:
-228
View File
@@ -1,228 +0,0 @@
"""Runtime in-memory byte index loaded from precomputed sidecar parquet."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME
from .mp4_episode_slice import episode_custom_frame_mappings_json
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class EpisodeSliceLookup:
global_episode_id: int
file_id: int
mdat_offset: int
mdat_length: int
frame_count: int
first_pts: float
last_pts: float
avg_fps: float
@property
def fetch_bytes(self) -> int:
return self.mdat_length
@dataclass(frozen=True)
class FileLookup:
file_id: int
file_path: str
file_size: int
moov_offset: int
moov_length: int
header_length: int
faststart: bool
avg_fps: float
codec: str
class EpisodeByteIndex:
"""Columnar byte-index resident in numpy arrays for O(1) episode lookup."""
def __init__(
self,
index_dir: str | Path | None,
*,
video_keys: list[str],
num_episodes: int,
mmap: bool = True,
files_table: pa.Table | None = None,
episodes_table: pa.Table | None = None,
mp4_by_rel: dict[str, Any] | None = None,
):
self.index_dir = Path(index_dir) if index_dir is not None else None
self.video_keys = list(video_keys)
self.num_episodes = num_episodes
self.num_cameras = len(video_keys)
self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)}
self._mp4_by_rel = mp4_by_rel
self._frame_mappings_by_gid: dict[int, bytes] = {}
t0 = time.perf_counter()
if files_table is not None and episodes_table is not None:
files_tbl, episodes_tbl = files_table, episodes_table
else:
if self.index_dir is None:
raise ValueError("index_dir or in-memory tables required")
files_path = self.index_dir / FILES_NAME
episodes_path = self.index_dir / EPISODES_NAME
if not files_path.exists() or not episodes_path.exists():
raise FileNotFoundError(f"byte index missing under {self.index_dir}")
files_tbl = pq.read_table(files_path, memory_map=mmap)
episodes_tbl = pq.read_table(episodes_path, memory_map=mmap)
self._load_tables(files_tbl, episodes_tbl, mmap=mmap)
self.build_time_s = time.perf_counter() - t0
self.load_time_s = self.build_time_s
def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None:
def col(tbl, name: str):
array = tbl.column(name).combine_chunks()
if pa.types.is_boolean(array.type):
return array.to_numpy(zero_copy_only=False)
return array.to_numpy()
self.file_id = col(files_tbl, "file_id")
self.file_path = files_tbl.column("file_path").to_pylist()
self.file_size = col(files_tbl, "file_size")
self.moov_offset = col(files_tbl, "moov_offset")
self.moov_length = col(files_tbl, "moov_length")
self.header_length = col(files_tbl, "header_length")
self.faststart = col(files_tbl, "faststart")
self.file_avg_fps = col(files_tbl, "avg_fps")
self.codec = files_tbl.column("codec").to_pylist()
ep = episodes_tbl
n = len(ep)
gid = col(ep, "global_episode_id")
order = np.argsort(gid)
self._global_episode_id = gid[order]
self._episode_index = col(ep, "episode_index")[order]
self._camera_index = col(ep, "camera_index")[order]
self._file_id = col(ep, "file_id")[order]
self._mdat_offset = col(ep, "mdat_offset")[order]
self._mdat_length = col(ep, "mdat_length")[order]
self._frame_count = col(ep, "frame_count")[order]
self._first_pts = col(ep, "first_pts")[order]
self._last_pts = col(ep, "last_pts")[order]
expected = self.num_episodes * self.num_cameras
if n != expected:
raise ValueError(f"byte index episodes rows {n} != expected {expected}")
if self.index_dir is not None:
keyframes_path = self.index_dir / KEYFRAMES_NAME
if keyframes_path.exists():
kf_tbl = pq.read_table(keyframes_path, memory_map=mmap)
self._keyframes_rows = len(kf_tbl)
else:
self._keyframes_rows = 0
else:
self._keyframes_rows = 0
self.resident_bytes = int(
self._global_episode_id.nbytes
+ self._file_id.nbytes
+ self._mdat_offset.nbytes
+ self._mdat_length.nbytes
+ self.file_size.nbytes
)
@classmethod
def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex:
return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes)
@classmethod
def from_memory_build(
cls,
meta,
data_root: str,
*,
workers: int = 8,
max_episodes: int | None = None,
include_frame_mappings_cache: bool = True,
) -> EpisodeByteIndex:
"""Build a complete byte index in RAM (no parquet write, no dataset push)."""
from .byte_index_builder import build_byte_index_in_memory
return build_byte_index_in_memory(
meta,
data_root,
workers=workers,
max_episodes=max_episodes,
include_frame_mappings_cache=include_frame_mappings_cache,
)
def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup:
cam_idx = self._cam_to_idx[camera_key]
gid = episode_index * self.num_cameras + cam_idx
row = int(gid)
if row < 0 or row >= len(self._global_episode_id):
raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
file_id = int(self._file_id[row])
return EpisodeSliceLookup(
global_episode_id=gid,
file_id=file_id,
mdat_offset=int(self._mdat_offset[row]),
mdat_length=int(self._mdat_length[row]),
frame_count=int(self._frame_count[row]),
first_pts=float(self._first_pts[row]),
last_pts=float(self._last_pts[row]),
avg_fps=float(self.file_avg_fps[file_id]),
)
def file_lookup(self, file_id: int) -> FileLookup:
return FileLookup(
file_id=file_id,
file_path=self.file_path[file_id],
file_size=int(self.file_size[file_id]),
moov_offset=int(self.moov_offset[file_id]),
moov_length=int(self.moov_length[file_id]),
header_length=int(self.header_length[file_id]),
faststart=bool(self.faststart[file_id]),
avg_fps=float(self.file_avg_fps[file_id]),
codec=self.codec[file_id],
)
def header_byte_range(self, file_id: int) -> tuple[int, int]:
length = int(self.header_length[file_id])
return 0, max(0, length - 1)
def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None:
cam_idx = self._cam_to_idx[camera_key]
gid = episode_index * self.num_cameras + cam_idx
cached = self._frame_mappings_by_gid.get(gid)
if cached is not None:
return cached
if self._mp4_by_rel is None:
return None
lookup = self.lookup(episode_index, camera_key)
rel = self.file_path[lookup.file_id]
mp4_index = self._mp4_by_rel.get(rel)
if mp4_index is None:
return None
payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts)
self._frame_mappings_by_gid[gid] = payload
return payload
def stats_dict(self) -> dict[str, float | int]:
return {
"load_time_s": self.load_time_s,
"build_time_s": self.build_time_s,
"resident_bytes": self.resident_bytes,
"frame_mappings_cached": len(self._frame_mappings_by_gid),
"mp4_indices_cached": len(self._mp4_by_rel or {}),
"num_files": len(self.file_path),
"num_episode_rows": len(self._global_episode_id),
}
-281
View File
@@ -1,281 +0,0 @@
"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch."""
from __future__ import annotations
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import fsspec
import pyarrow as pa
import pyarrow.parquet as pq
from .mp4_episode_slice import (
HEADER_PROBE_BYTES,
MAX_HEADER_PROBE_BYTES,
average_fps_from_index,
episode_keyframes,
parse_mp4_file_layout,
parse_mp4_index,
)
logger = logging.getLogger(__name__)
BYTE_INDEX_DIR = "meta/byte_index"
FILES_NAME = "files.parquet"
EPISODES_NAME = "episodes.parquet"
KEYFRAMES_NAME = "keyframes.parquet"
@dataclass
class IndexedFile:
file_id: int
file_path: str
file_size: int
moov_offset: int
moov_length: int
header_length: int
faststart: bool
avg_fps: float
codec: str
def fetch_header_bytes(path: str, file_size: int) -> bytes:
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
probe = HEADER_PROBE_BYTES
while True:
with fs.open(path, "rb", block_size=max(probe, 2**20), cache_type="none") as f:
header = f.read(min(probe, file_size))
try:
parse_mp4_file_layout(header, file_size)
return header
except ValueError as exc:
if probe >= min(MAX_HEADER_PROBE_BYTES, file_size) or "mdat box not found" not in str(exc):
raise
probe = min(probe * 2, MAX_HEADER_PROBE_BYTES, file_size)
def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]:
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
file_size = fs.info(path)["size"]
header = fetch_header_bytes(path, file_size)
layout = parse_mp4_file_layout(header, file_size)
if not layout.faststart:
logger.warning("non-faststart MP4 (moov after mdat): %s", path)
mp4_index = parse_mp4_index(header, file_size)
indexed = IndexedFile(
file_id=-1,
file_path=rel_path or path,
file_size=file_size,
moov_offset=layout.moov_offset,
moov_length=layout.moov_length,
header_length=layout.header_end,
faststart=layout.faststart,
avg_fps=average_fps_from_index(mp4_index),
codec=layout.codec,
)
return indexed, mp4_index
def build_byte_index_tables(
meta,
data_root: str,
*,
file_paths: list[str] | None = None,
include_keyframes: bool = True,
workers: int = 8,
existing_files: dict[str, int] | None = None,
max_episodes: int | None = None,
return_mp4_indices: bool = False,
complete_files_table: bool = False,
) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]:
"""Build files/episodes/(optional keyframes) Arrow tables."""
video_keys = list(meta.video_keys)
n_cams = len(video_keys)
cam_to_idx = {cam: i for i, cam in enumerate(video_keys)}
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
rel_paths: set[str] = set()
for ep_idx in range(num_episodes):
for cam in video_keys:
rel_paths.add(str(meta.get_video_file_path(ep_idx, cam)))
path_by_rel = {rel: f"{data_root.rstrip('/')}/{rel}" for rel in sorted(rel_paths)}
if file_paths is None:
file_paths = list(path_by_rel.values())
rel_by_path = {path_by_rel[rel]: rel for rel in path_by_rel}
existing_files = existing_files or {}
file_meta_by_rel: dict[str, dict[str, Any]] = {}
mp4_by_rel: dict[str, Any] = {}
next_file_id = max(existing_files.values(), default=-1) + 1
to_index = [rel for rel in sorted(rel_paths) if rel not in existing_files]
if to_index:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel for rel in to_index
}
for fut in as_completed(futures):
rel = futures[fut]
indexed, mp4_index = fut.result()
indexed.file_id = next_file_id
mp4_by_rel[rel] = mp4_index
file_meta_by_rel[rel] = {
"file_id": indexed.file_id,
"file_path": rel,
"file_size": indexed.file_size,
"moov_offset": indexed.moov_offset,
"moov_length": indexed.moov_length,
"header_length": indexed.header_length,
"faststart": indexed.faststart,
"avg_fps": indexed.avg_fps,
"codec": indexed.codec,
}
existing_files[rel] = indexed.file_id
next_file_id += 1
missing_rels = {
str(meta.get_video_file_path(ep, cam))
for ep in range(num_episodes)
for cam in video_keys
} - set(mp4_by_rel.keys())
if missing_rels:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel
for rel in missing_rels
if rel not in mp4_by_rel
}
for fut in as_completed(futures):
rel = futures[fut]
_, mp4_index = fut.result()
mp4_by_rel[rel] = mp4_index
episode_rows: list[dict[str, Any]] = []
keyframe_rows: list[dict[str, Any]] = []
for ep_idx in range(num_episodes):
for cam in video_keys:
rel = str(meta.get_video_file_path(ep_idx, cam))
path = f"{data_root.rstrip('/')}/{rel}"
if rel not in existing_files:
raise KeyError(f"file not indexed: {rel}")
mp4_index = mp4_by_rel[rel]
ep = meta.episodes[ep_idx]
from_ts = float(ep[f"videos/{cam}/from_timestamp"])
to_ts = float(ep[f"videos/{cam}/to_timestamp"])
span = mp4_index.episode_byte_span(from_ts, to_ts)
global_episode_id = ep_idx * n_cams + cam_to_idx[cam]
mdat_length = span.slice_hi - span.slice_lo + 1
episode_rows.append(
{
"global_episode_id": global_episode_id,
"episode_index": ep_idx,
"camera_key": cam,
"camera_index": cam_to_idx[cam],
"file_id": existing_files[rel],
"mdat_offset": span.slice_lo,
"mdat_length": mdat_length,
"frame_count": max(1, round((to_ts - from_ts) * meta.fps)),
"first_pts": from_ts,
"last_pts": to_ts,
}
)
if include_keyframes:
timescale = mp4_index.timescale
for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts):
keyframe_rows.append(
{
"global_episode_id": global_episode_id,
"pts": int(round(pts_s * timescale)),
"byte_offset": byte_off,
}
)
referenced_rels = {
str(meta.get_video_file_path(ep, cam)) for ep in range(num_episodes) for cam in video_keys
}
if complete_files_table:
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(referenced_rels)])
elif to_index:
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(to_index)])
else:
files_table = None
episodes_table = pa.Table.from_pylist(episode_rows)
keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None
if return_mp4_indices:
return files_table, episodes_table, keyframes_table, mp4_by_rel
return files_table, episodes_table, keyframes_table
def build_byte_index_in_memory(
meta,
data_root: str,
*,
workers: int = 8,
max_episodes: int | None = None,
include_frame_mappings_cache: bool = False,
):
"""Build a complete byte index resident in RAM (no parquet write, no dataset push)."""
from .byte_index import EpisodeByteIndex
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
files_tbl, episodes_tbl, _, mp4_by_rel = build_byte_index_tables(
meta,
data_root,
include_keyframes=False,
workers=workers,
max_episodes=max_episodes,
return_mp4_indices=True,
complete_files_table=True,
)
index = EpisodeByteIndex(
None,
video_keys=list(meta.video_keys),
num_episodes=num_episodes,
files_table=files_tbl,
episodes_table=episodes_tbl,
mp4_by_rel=mp4_by_rel,
)
if include_frame_mappings_cache:
for ep_idx in range(num_episodes):
for cam in meta.video_keys:
index.custom_frame_mappings(ep_idx, cam)
return index
def write_byte_index(
output_dir: Path,
files_table: pa.Table | None,
episodes_table: pa.Table,
keyframes_table: pa.Table | None = None,
*,
merge_existing: bool = True,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
files_path = output_dir / FILES_NAME
episodes_path = output_dir / EPISODES_NAME
keyframes_path = output_dir / KEYFRAMES_NAME
if merge_existing and files_path.exists() and files_table is not None:
prev = pq.read_table(files_path)
files_table = pa.concat_tables([prev, files_table])
if files_table is not None:
pq.write_table(files_table, files_path)
pq.write_table(episodes_table, episodes_path)
if keyframes_table is not None:
if merge_existing and keyframes_path.exists():
keyframes_table = pa.concat_tables([pq.read_table(keyframes_path), keyframes_table])
pq.write_table(keyframes_table, keyframes_path)
def load_existing_file_ids(index_dir: Path) -> dict[str, int]:
files_path = index_dir / FILES_NAME
if not files_path.exists():
return {}
table = pq.read_table(files_path, columns=["file_id", "file_path"])
return {row["file_path"]: int(row["file_id"]) for row in table.to_pylist()}
+2 -11
View File
@@ -945,17 +945,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
ep_dataset = embed_images(ep_dataset)
table = ep_dataset.with_format("arrow")[:]
# Emit several row groups with a page index instead of one giant row group. A single row group forces
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
# groups bounds per-shard memory while keeping groups large enough to scan
# efficiently; the page index lets readers skip to the pages they need.
target_row_group_bytes = 32 * 1024 * 1024
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
)
writer.write_table(table, row_group_size=row_group_size)
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
writer.write_table(table)
writer.close()
-263
View File
@@ -1,263 +0,0 @@
"""Node-local LRU byte cache using precomputed byte-index manifest sidecars."""
from __future__ import annotations
import logging
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
import fsspec
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
from .mp4_episode_slice import SparseMp4Reader
from .torchcodec_utils import open_video_decoder
logger = logging.getLogger(__name__)
@dataclass
class CacheStats:
hits: int = 0
misses: int = 0
bytes_fetched: int = 0
full_file_fallbacks: int = 0
prefetch_submitted: int = 0
prefetch_waits: int = 0
mdat_slices: int = 0
prefix_fetches: int = 0
fetch_to_buffer_s: float = 0.0
buffer_to_decoder_s: float = 0.0
buffer_hit_decoder_s: float = 0.0
decode_frame_s: float = 0.0
decode_frames: int = 0
def merge(self, other: CacheStats) -> None:
for name in self.__dataclass_fields__:
setattr(self, name, getattr(self, name) + getattr(other, name))
def stats_dict(self) -> dict[str, int | float]:
avg_miss = self.bytes_fetched / max(1, self.misses)
return {
"byte_cache_hits": self.hits,
"byte_cache_misses": self.misses,
"byte_cache_bytes_fetched": self.bytes_fetched,
"byte_cache_bytes_per_miss": avg_miss,
"byte_cache_full_file_fallbacks": self.full_file_fallbacks,
"byte_cache_prefetch_submitted": self.prefetch_submitted,
"byte_cache_prefetch_waits": self.prefetch_waits,
"byte_cache_mdat_slices": self.mdat_slices,
"byte_cache_prefix_fetches": self.prefix_fetches,
"fetch_to_buffer_ms_per_miss": 1000 * self.fetch_to_buffer_s / max(1, self.misses),
"buffer_to_decoder_ms_per_miss": 1000 * self.buffer_to_decoder_s / max(1, self.misses),
"buffer_hit_decoder_ms_per_hit": 1000 * self.buffer_hit_decoder_s / max(1, self.hits),
"decode_ms_per_frame": 1000 * self.decode_frame_s / max(1, self.decode_frames),
}
@dataclass
class _EpisodeEntry:
decoders: dict[str, Any] = field(default_factory=dict)
ready: threading.Event = field(default_factory=threading.Event)
error: Exception | None = None
class RangeFetcher:
"""Sequential byte-range GETs via fsspec."""
def __init__(self, path: str):
self.path = path
self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
def fetch(self, lo: int, hi: int) -> bytes:
if hi < lo:
return b""
with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f:
f.seek(lo)
return f.read(hi - lo + 1)
class EpisodeByteCache:
"""Manifest-driven episode MP4 fetch + in-memory sparse decode."""
MAX_BYTES_PER_MISS = 25 * 1024 * 1024
def __init__(
self,
byte_index: EpisodeByteIndex,
max_bytes: int,
*,
data_root: str,
max_prefetch_workers: int = 4,
):
if max_bytes <= 0:
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
self.byte_index = byte_index
self.max_bytes = max_bytes
self.data_root = data_root.rstrip("/")
self._bytes_used = 0
self._lock = threading.Lock()
self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict()
self._header_cache: dict[int, bytes] = {}
self._fetcher_cache: dict[int, RangeFetcher] = {}
self._episodes: dict[int, _EpisodeEntry] = {}
self._stats = CacheStats()
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
self._futures: dict[int, Future] = {}
@property
def stats(self) -> CacheStats:
with self._lock:
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
def submit_prefetch(self, ep_idx: int) -> None:
with self._lock:
if ep_idx in self._episodes or ep_idx in self._futures:
return
self._stats.prefetch_submitted += 1
fut = self._executor.submit(self._prefetch_episode, ep_idx)
self._futures[ep_idx] = fut
def ensure_ready(self, ep_idx: int) -> None:
with self._lock:
fut = self._futures.pop(ep_idx, None)
if fut is not None:
with self._lock:
self._stats.prefetch_waits += 1
fut.result()
entry = self._episodes.get(ep_idx)
if entry is None:
raise KeyError(f"episode {ep_idx} not prefetched")
if entry.error is not None:
raise entry.error
entry.ready.wait()
def get_decoder(self, ep_idx: int, video_key: str) -> Any:
entry = self._episodes[ep_idx]
if entry.error is not None:
raise entry.error
entry.ready.wait()
return entry.decoders[video_key]
def close(self) -> None:
self._executor.shutdown(wait=False, cancel_futures=True)
def _prefetch_episode(self, ep_idx: int) -> None:
entry = _EpisodeEntry()
self._episodes[ep_idx] = entry
try:
for cam in self.byte_index.video_keys:
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
except Exception as exc:
entry.error = exc
finally:
entry.ready.set()
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
key = (ep_idx, cam)
with self._lock:
cached = self._cache.get(key)
if cached is not None:
self._cache.move_to_end(key)
self._stats.hits += 1
payload, _ = cached
t0 = time.perf_counter()
dec = self._decoder_from_payload(payload, ep_idx, cam)
with self._lock:
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
return dec
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
with self._lock:
self._stats.misses += 1
if payload_bytes > self.MAX_BYTES_PER_MISS:
logger.warning(
"byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s",
payload_bytes / 1e6,
ep_idx,
cam,
)
self._evict_until(payload_bytes)
self._cache[key] = (payload, payload_bytes)
self._bytes_used += payload_bytes
return dec
def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]:
lookup = self.byte_index.lookup(ep_idx, cam)
file_info = self.byte_index.file_lookup(lookup.file_id)
fetcher = self._get_fetcher(lookup.file_id, file_info.file_path)
t_fetch = time.perf_counter()
header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length)
lo = lookup.mdat_offset
hi = lo + lookup.mdat_length - 1
mdat = fetcher.fetch(lo, hi)
fetch_s = time.perf_counter() - t_fetch
nbytes = len(header) + len(mdat)
with self._lock:
self._stats.bytes_fetched += nbytes
self._stats.mdat_slices += 1
self._stats.fetch_to_buffer_s += fetch_s
def lazy_fetch(pos: int, end: int) -> bytes:
data = fetcher.fetch(pos, end)
with self._lock:
self._stats.bytes_fetched += len(data)
return data
reader = SparseMp4Reader(
file_size=file_info.file_size,
header=header,
mdat_lo=lo,
mdat_bytes=mdat,
lazy_fetch=lazy_fetch,
)
t_init = time.perf_counter()
dec = self._decoder_from_payload(reader, ep_idx, cam)
self._validate_decoder(dec, lookup)
init_s = time.perf_counter() - t_init
with self._lock:
self._stats.buffer_to_decoder_s += init_s
self._rewind_payload(reader)
return reader, nbytes, dec
def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher:
if file_id not in self._fetcher_cache:
path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}"
self._fetcher_cache[file_id] = RangeFetcher(path)
return self._fetcher_cache[file_id]
def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes:
if file_id in self._header_cache:
return self._header_cache[file_id]
hi = max(0, header_length - 1)
header = fetcher.fetch(0, hi)
with self._lock:
self._header_cache[file_id] = header
self._stats.bytes_fetched += len(header)
return header
def _decoder_from_payload(
self, payload: SparseMp4Reader, ep_idx: int, cam: str
) -> Any:
payload.seek(0)
mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
return open_video_decoder(payload, frame_mappings=mappings)
def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None:
begin = float(dec.metadata.begin_stream_seconds)
end = float(dec.metadata.end_stream_seconds)
duration = max(0.01, end - begin)
for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
dec.get_frames_played_at([ts]).data
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
payload.seek(0)
def _evict_until(self, need: int) -> None:
while self._bytes_used + need > self.max_bytes and self._cache:
_, (_, size) = self._cache.popitem(last=False)
self._bytes_used -= size
+1 -1
View File
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
return_uint8=True,
)
-555
View File
@@ -1,555 +0,0 @@
"""MP4 moov parsing and tight per-episode mdat byte-range fetching.
LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head).
For streaming we fetch only the file header plus the episode's contiguous mdat span
instead of the ``0..episode_end`` prefix.
"""
from __future__ import annotations
import io
import struct
import threading
from dataclasses import dataclass, field
from typing import Callable
KEYFRAME_PAD_S = 0.1
HEADER_PROBE_BYTES = 4 * 1024 * 1024
MAX_HEADER_PROBE_BYTES = 16 * 1024 * 1024
@dataclass
class Mp4FileLayout:
file_size: int
moov_offset: int
moov_length: int
header_end: int
mdat_offset: int
mdat_size: int
faststart: bool
codec: str
def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout:
"""Return top-level MP4 layout (moov/mdat positions, faststart flag)."""
boxes = list(_iter_boxes(header_bytes))
moov_offset = mdat_offset = -1
moov_length = mdat_size = 0
for off, size, typ, _ in boxes:
if typ == b"moov" and moov_offset < 0:
moov_offset, moov_length = off, size
if typ == b"mdat" and mdat_offset < 0:
mdat_offset, mdat_size = off, size
if moov_offset < 0:
raise ValueError("moov box not found in header probe")
if mdat_offset < 0:
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
faststart = moov_offset < mdat_offset
header_end = mdat_offset
codec = _parse_video_codec(header_bytes)
return Mp4FileLayout(
file_size=file_size,
moov_offset=moov_offset,
moov_length=moov_length,
header_end=header_end,
mdat_offset=mdat_offset,
mdat_size=mdat_size,
faststart=faststart,
codec=codec,
)
def _parse_video_codec(header_bytes: bytes) -> str:
moov = _find_box_payload(header_bytes, b"moov")
if moov is None:
return "unknown"
trak = _find_video_trak(moov)
if trak is None:
return "unknown"
stsd = _find_box_payload(_find_box_payload(trak, b"stbl") or b"", b"stsd")
if stsd is None or len(stsd) < 12:
return "unknown"
# stsd: version(1)+flags(3)+entry_count(4)+entry_size(4)+codec(4)
if len(stsd) >= 12:
return stsd[8:12].decode("latin1", errors="replace").strip("\x00")
return "unknown"
def average_fps_from_index(index: Mp4VideoIndex) -> float:
index.ensure_tables()
if index.num_samples < 2:
return 30.0
duration = index.sample_pts(index.num_samples - 1)
if duration <= 0:
return 30.0
return index.num_samples / duration
def episode_custom_frame_mappings_json(
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> bytes:
"""Build TorchCodec ``custom_frame_mappings`` JSON for one episode span."""
import json
index.ensure_tables()
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
hi_idx = min(hi_idx, index.num_samples - 1)
lo_idx = _keyframe_back(index.sync_samples, lo_idx)
sync = set(index.sync_samples)
timescale = index.timescale
# stts deltas for duration per sample (expand stts entries to per-sample delta)
sample_deltas: list[int] = []
for count, delta in index.stts:
sample_deltas.extend([delta] * count)
while len(sample_deltas) < index.num_samples:
sample_deltas.append(sample_deltas[-1] if sample_deltas else timescale // 30)
frames = []
for idx in range(lo_idx, hi_idx + 1):
frames.append(
{
"pts": int(round(index._pts[idx] * timescale)),
"duration": int(sample_deltas[idx]),
"key_frame": int((idx + 1) in sync) if sync else int(idx == lo_idx),
}
)
return json.dumps({"frames": frames}).encode()
def episode_keyframes(
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> list[tuple[float, int]]:
"""Return (pts_seconds, byte_offset) for sync samples in the episode span."""
index.ensure_tables()
span = index.episode_byte_span(from_ts, to_ts, keyframe_pad_s)
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
if not index.sync_samples:
return [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
out: list[tuple[float, int]] = []
for sync_one_based in index.sync_samples:
idx = sync_one_based - 1
if lo_idx <= idx <= hi_idx:
out.append((index.sample_pts(idx), index.sample_offset(idx)))
return out or [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
@dataclass
class EpisodeByteSpan:
"""Absolute file byte ranges to fetch for one episode."""
file_size: int
header_end: int
slice_lo: int
slice_hi: int
@property
def header_bytes(self) -> tuple[int, int]:
return 0, self.header_end - 1
@property
def mdat_bytes(self) -> tuple[int, int]:
return self.slice_lo, self.slice_hi
@property
def total_fetch_bytes(self) -> int:
header = self.header_end
mdat = self.slice_hi - self.slice_lo + 1
return header + mdat
@dataclass
class Mp4VideoIndex:
file_size: int
header_end: int
mdat_offset: int
mdat_size: int
timescale: int
stts: list[tuple[int, int]]
stsz: list[int]
stsc: list[tuple[int, int, int]]
stco: list[int]
sync_samples: list[int]
_pts: list[float] = field(default_factory=list, repr=False)
_offsets: list[int] = field(default_factory=list, repr=False)
def ensure_tables(self) -> None:
if self._pts:
return
self._pts = _pts_from_stts(self.stts, self.timescale)
self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz)
@property
def num_samples(self) -> int:
return len(self.stsz)
def sample_pts(self, index: int) -> float:
self.ensure_tables()
return self._pts[index]
def sample_offset(self, index: int) -> int:
self.ensure_tables()
index = max(0, min(index, len(self._offsets) - 1))
return self._offsets[index]
def sample_end(self, index: int) -> int:
return self.sample_offset(index) + self.stsz[index]
def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan:
self.ensure_tables()
n = self.num_samples
if n == 0:
raise ValueError("MP4 has no video samples")
pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts))
lo_ts = max(0.0, from_ts - pad)
hi_ts = to_ts + pad
lo_idx = _first_sample_at_or_after(self._pts, lo_ts)
hi_idx = _last_sample_at_or_before(self._pts, hi_ts)
hi_idx = min(hi_idx, n - 1)
lo_idx = min(lo_idx, n - 1)
lo_idx = _keyframe_back(self.sync_samples, lo_idx)
slice_lo = self.sample_offset(lo_idx)
slice_hi = self.sample_end(min(hi_idx, len(self._offsets) - 1))
return EpisodeByteSpan(
file_size=self.file_size,
header_end=self.header_end,
slice_lo=slice_lo,
slice_hi=min(slice_hi, self.file_size - 1),
)
class SparseMp4Reader(io.BufferedIOBase):
"""Range-backed MP4 reader: header + one mdat span at absolute offsets."""
def __init__(
self,
file_size: int,
header: bytes,
mdat_lo: int,
mdat_bytes: bytes,
lazy_fetch: Callable[[int, int], bytes] | None = None,
):
self._size = file_size
self._header = header
self._mdat_lo = mdat_lo
self._mdat_hi = mdat_lo + len(mdat_bytes)
self._mdat = mdat_bytes
self._lazy_fetch = lazy_fetch
self._pos = 0
self._lock = threading.Lock()
def readable(self) -> bool:
return True
def seekable(self) -> bool:
return True
def tell(self) -> int:
return self._pos
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
self._pos = offset
elif whence == io.SEEK_CUR:
self._pos += offset
elif whence == io.SEEK_END:
self._pos = self._size + offset
else:
raise ValueError(f"invalid whence: {whence}")
self._pos = max(0, min(self._pos, self._size))
return self._pos
def read(self, size: int = -1) -> bytes:
if size < 0:
size = self._size - self._pos
if size <= 0:
return b""
out = bytearray()
remaining = size
pos = self._pos
while remaining > 0 and pos < self._size:
chunk = self._read_at(pos, remaining)
if not chunk:
break
out.extend(chunk)
pos += len(chunk)
remaining -= len(chunk)
self._pos = pos
return bytes(out)
def _read_at(self, pos: int, n: int) -> bytes:
header_len = len(self._header)
if pos < header_len:
end = min(pos + n, header_len)
return self._header[pos:end]
if self._mdat_lo <= pos < self._mdat_hi:
end = min(pos + n, self._mdat_hi)
off = pos - self._mdat_lo
return self._mdat[off : off + (end - pos)]
if self._lazy_fetch is not None:
with self._lock:
end = min(pos + n, self._size)
return self._lazy_fetch(pos, end - 1)
return b"\x00" * min(n, self._size - pos)
def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex:
"""Parse moov sample tables from the file header (faststart layout)."""
layout = parse_mp4_file_layout(header_bytes, file_size)
mdat_offset, mdat_size = layout.mdat_offset, layout.mdat_size
moov = _find_box_payload(header_bytes, b"moov")
if moov is None:
raise ValueError("moov box not found in MP4 header probe")
trak = _find_video_trak(moov)
if trak is None:
raise ValueError("video trak not found in moov")
mdhd = _find_box_payload(trak, b"mdhd")
if mdhd is None:
raise ValueError("mdhd not found")
timescale = _parse_mdhd_timescale(mdhd)
stbl = _find_box_payload(trak, b"stbl")
if stbl is None:
raise ValueError("stbl not found")
stts = _parse_stts(_find_box_payload(stbl, b"stts"))
stsz = _parse_stsz(_find_box_payload(stbl, b"stsz"))
stsc = _parse_stsc(_find_box_payload(stbl, b"stsc"))
stco_payload = _find_box_payload(stbl, b"stco")
co64_payload = _find_box_payload(stbl, b"co64")
if stco_payload is not None:
stco = _parse_stco(stco_payload)
elif co64_payload is not None:
stco = _parse_co64(co64_payload)
else:
raise ValueError("stco/co64 not found")
stss_payload = _find_box_payload(stbl, b"stss")
sync_samples = _parse_stss(stss_payload) if stss_payload else []
return Mp4VideoIndex(
file_size=file_size,
header_end=layout.header_end,
mdat_offset=mdat_offset,
mdat_size=mdat_size,
timescale=timescale,
stts=stts,
stsz=stsz,
stsc=stsc,
stco=stco,
sync_samples=sync_samples,
)
def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None:
if offset + 8 > len(data):
return None
size, typ = struct.unpack_from(">I4s", data, offset)
header = 8
if size == 1:
if offset + 16 > len(data):
return None
size = struct.unpack_from(">Q", data, offset + 8)[0]
header = 16
elif size == 0:
size = len(data) - offset
return size, typ, header
def _iter_boxes(data: bytes, start: int = 0, end: int | None = None):
end = end if end is not None else len(data)
off = start
while off + 8 <= end:
hdr = _box_header(data, off)
if hdr is None or hdr[0] < hdr[2]:
break
size, typ, header = hdr
yield off, size, typ, data[off + header : off + size]
off += size
def _find_box_payload(data: bytes, target: bytes) -> bytes | None:
for _, _, typ, payload in _iter_boxes(data):
if typ == target:
return payload
if typ in (b"moov", b"trak", b"mdia", b"minf", b"stbl"):
found = _find_box_payload(payload, target)
if found is not None:
return found
return None
def _find_video_trak(moov: bytes) -> bytes | None:
for _, _, typ, payload in _iter_boxes(moov):
if typ != b"trak":
continue
hdlr = _find_box_payload(payload, b"hdlr")
if hdlr is not None and len(hdlr) >= 12 and hdlr[8:12] == b"vide":
return payload
return None
def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]:
for off, size, typ, _ in _iter_boxes(header_bytes):
if typ == b"mdat":
return off, size
# mdat may start beyond probe; scan from file_size hint unavailable — require probe hit
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
def _parse_mdhd_timescale(mdhd: bytes) -> int:
version = mdhd[0]
if version == 0:
return struct.unpack_from(">I", mdhd, 12)[0]
return struct.unpack_from(">I", mdhd, 20)[0]
def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]:
if stts is None:
raise ValueError("stts missing")
count = struct.unpack_from(">I", stts, 4)[0]
out = []
off = 8
for _ in range(count):
sample_count, delta = struct.unpack_from(">II", stts, off)
out.append((sample_count, delta))
off += 8
return out
def _parse_stsz(stsz: bytes | None) -> list[int]:
if stsz is None:
raise ValueError("stsz missing")
sample_size, sample_count = struct.unpack_from(">II", stsz, 4)
if sample_size != 0:
return [sample_size] * sample_count
off = 12
return list(struct.unpack_from(f">{sample_count}I", stsz, off))
def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]:
if stsc is None:
raise ValueError("stsc missing")
count = struct.unpack_from(">I", stsc, 4)[0]
out = []
off = 8
for _ in range(count):
first_chunk, samples_per_chunk, sample_desc = struct.unpack_from(">III", stsc, off)
out.append((first_chunk, samples_per_chunk, sample_desc))
off += 12
return out
def _parse_stco(stco: bytes) -> list[int]:
count = struct.unpack_from(">I", stco, 4)[0]
return list(struct.unpack_from(f">{count}I", stco, 8))
def _parse_co64(co64: bytes) -> list[int]:
count = struct.unpack_from(">I", co64, 4)[0]
return [struct.unpack_from(">Q", co64, 8 + i * 8)[0] for i in range(count)]
def _parse_stss(stss: bytes) -> list[int]:
count = struct.unpack_from(">I", stss, 4)[0]
return list(struct.unpack_from(f">{count}I", stss, 8))
def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
pts: list[float] = []
t = 0
for count, delta in stts:
for _ in range(count):
pts.append(t / timescale)
t += delta
return pts
def _sample_byte_offsets(
stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
) -> list[int]:
if not stsc:
stsc = [(1, len(stsz), 1)]
offsets: list[int] = []
chunk_idx = 0
sample_idx = 0
sc_idx = 0
num_chunks = len(stco)
while chunk_idx < num_chunks and sample_idx < len(stsz):
first_chunk, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
if sc_idx + 1 < len(stsc):
next_first = stsc[sc_idx + 1][0]
chunks_in_entry = next_first - first_chunk
else:
chunks_in_entry = num_chunks - chunk_idx
for _ in range(chunks_in_entry):
if chunk_idx >= num_chunks:
break
offset = stco[chunk_idx]
_, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
for _ in range(samples_per_chunk):
if sample_idx >= len(stsz):
break
offsets.append(offset)
offset += stsz[sample_idx]
sample_idx += 1
chunk_idx += 1
sc_idx += 1
if len(offsets) < len(stsz):
# Pad with last known offset progression for malformed stsc edge cases.
last = offsets[-1] if offsets else 0
while len(offsets) < len(stsz):
idx = len(offsets)
offsets.append(last)
last += stsz[idx]
return offsets
def _first_sample_at_or_after(pts: list[float], ts: float) -> int:
lo, hi = 0, len(pts)
while lo < hi:
mid = (lo + hi) // 2
if pts[mid] < ts:
lo = mid + 1
else:
hi = mid
return min(lo, len(pts) - 1)
def _last_sample_at_or_before(pts: list[float], ts: float) -> int:
lo, hi = 0, len(pts)
while lo < hi:
mid = (lo + hi) // 2
if pts[mid] <= ts:
lo = mid + 1
else:
hi = mid
return max(0, lo - 1)
def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
if not sync_samples:
return max(0, sample_idx - 2)
# stss stores 1-based sample numbers
one_based = sample_idx + 1
prev = [s for s in sync_samples if s <= one_based]
if prev:
return prev[-1] - 1
return 0
+1 -7
View File
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
File diff suppressed because it is too large Load Diff
-49
View File
@@ -1,49 +0,0 @@
"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings."""
from __future__ import annotations
import json
from typing import Any
import torch
from torchcodec import FrameBatch, _core as core
from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata
def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
data = json.loads(payload)
frames = data["frames"]
pts = torch.tensor([int(f["pts"]) for f in frames], dtype=torch.int64)
key = torch.tensor([bool(f["key_frame"]) for f in frames], dtype=torch.bool)
dur = torch.tensor([int(f["duration"]) for f in frames], dtype=torch.int64)
return pts, key, dur
class VideoDecoderLike:
"""Minimal VideoDecoder surface used by episode byte cache."""
def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None):
self._decoder = decoder
(
self.metadata,
self.stream_index,
self._begin_stream_seconds,
self._end_stream_seconds,
self._num_frames,
) = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index)
def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
return FrameBatch(*core.get_frames_by_pts(self._decoder, timestamps=seconds))
def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike:
"""Open a decoder on sparse or full MP4 IO, skipping metadata scan when mappings exist."""
if frame_mappings is None:
decoder = core.create_from_file_like(source, "approximate")
core.add_video_stream(decoder)
return VideoDecoderLike(decoder)
mappings = frame_mappings_tensors(frame_mappings)
decoder = core.create_from_file_like(source, "custom_frame_mappings")
core.add_video_stream(decoder, custom_frame_mappings=mappings)
return VideoDecoderLike(decoder)
+3 -10
View File
@@ -273,11 +273,7 @@ class VideoDecoderCache:
self._cache.move_to_end(video_path)
return entry[0]
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
# mostly-sequential reads torchcodec issues.
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
@@ -326,7 +322,6 @@ def decode_video_frames_torchcodec(
log_loaded_timestamps: bool = False,
decoder_cache: VideoDecoderCache | None = None,
return_uint8: bool = False,
episode_decoder: Any | None = None,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
@@ -348,10 +343,8 @@ def decode_video_frames_torchcodec(
if decoder_cache is None:
decoder_cache = _default_decoder_cache
if episode_decoder is not None:
decoder = episode_decoder
else:
decoder = decoder_cache.get_decoder(str(video_path))
# Use cached decoder instead of creating new one each time
decoder = decoder_cache.get_decoder(str(video_path))
loaded_ts = []
loaded_frames = []
+7 -1
View File
@@ -757,7 +757,7 @@ class RoboTwinEnvConfig(EnvConfig):
task: str = "beat_block_hammer" # single task or comma-separated list
fps: int = 25
episode_length: int = 300
episode_length: int = 1200
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
# Available cameras from RoboTwin's aloha-agilex embodiment: head_camera
@@ -768,6 +768,9 @@ class RoboTwinEnvConfig(EnvConfig):
# must equal what SAPIEN actually renders.
observation_height: int = 240
observation_width: int = 320
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
action_mode: str = "joint"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
@@ -784,6 +787,8 @@ class RoboTwinEnvConfig(EnvConfig):
)
def __post_init__(self):
if self.action_mode == "ee":
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
for cam in cam_list:
self.features[f"pixels/{cam}"] = PolicyFeature(
@@ -826,6 +831,7 @@ class RoboTwinEnvConfig(EnvConfig):
observation_height=self.observation_height,
observation_width=self.observation_width,
episode_length=self.episode_length,
action_mode=self.action_mode,
)
+154 -6
View File
@@ -17,6 +17,7 @@ from __future__ import annotations
import importlib
import logging
import os
from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
@@ -41,10 +42,117 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
"right_camera",
)
ACTION_DIM = 14 # 7 DOF × 2 arms
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
EEF_ACTION_DIM = 16
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
DEFAULT_EPISODE_LENGTH = 300
DEFAULT_EPISODE_LENGTH = 1200
OFFICIAL_INSTRUCTION_ENV = "LEROBOT_ROBOTWIN_OFFICIAL_INSTRUCTION"
OFFICIAL_INSTRUCTION_TYPE_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_TYPE"
OFFICIAL_INSTRUCTION_MAX_ENV = "LEROBOT_ROBOTWIN_INSTRUCTION_MAX"
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
"""Compose a single-arm predicted delta pose onto the initial pose.
``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
"""
from scipy.spatial.transform import Rotation
new_r = Rotation.from_quat(new_pose[3:7])
init_r = Rotation.from_quat(init_pose[3:7])
out_rot = (init_r * new_r).as_quat()
out_trans = new_pose[:3] + init_pose[:3]
return np.concatenate([out_trans, out_rot, new_pose[7:8]])
def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
"""Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
out = np.concatenate([left, right])
# Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
return out
def _env_flag(name: str, default: bool = False) -> bool:
raw = os.environ.get(name)
if raw is None:
return default
return raw.strip().lower() in {"1", "true", "yes", "on"}
def _arm_for_block(block: Any) -> str:
return "left" if float(block.get_pose().p[0]) < 0 else "right"
def _robotwin_blocks_episode_info(task_name: str, env: Any) -> dict[str, str] | None:
"""Infer the episode-info dict used by RoboTwin's official instruction generator for block ranking."""
if task_name == "blocks_ranking_rgb":
return {
"{A}": "red block",
"{B}": "green block",
"{C}": "blue block",
"{a}": _arm_for_block(env.block1),
"{b}": _arm_for_block(env.block2),
"{c}": _arm_for_block(env.block3),
}
if task_name == "blocks_ranking_size":
return {
"{A}": "large block",
"{B}": "medium block",
"{C}": "small block",
"{a}": _arm_for_block(env.block1),
"{b}": _arm_for_block(env.block2),
"{c}": _arm_for_block(env.block3),
}
return None
def _generate_robotwin_official_instruction(task_name: str, env: Any) -> str:
"""Generate language with RoboTwin's official task templates, matching its eval client."""
fallback = task_name.replace("_", " ")
episode_info = _robotwin_blocks_episode_info(task_name, env)
if episode_info is None:
logger.warning("Official RoboTwin instruction is not implemented for task=%s; using %r.", task_name, fallback)
return fallback
try:
from description.utils.generate_episode_instructions import generate_episode_descriptions
except Exception:
logger.warning("Failed to import RoboTwin official instruction generator; using %r.", fallback, exc_info=True)
return fallback
instruction_type = os.environ.get(OFFICIAL_INSTRUCTION_TYPE_ENV, "seen")
try:
max_descriptions = int(os.environ.get(OFFICIAL_INSTRUCTION_MAX_ENV, "1000000"))
except ValueError:
max_descriptions = 1000000
results = generate_episode_descriptions(task_name, [episode_info], max_descriptions=max_descriptions)
if not results:
logger.warning("RoboTwin generated no official instructions for task=%s; using %r.", task_name, fallback)
return fallback
options = results[0].get(instruction_type) or results[0].get("seen") or results[0].get("unseen")
if not options:
logger.warning(
"RoboTwin generated no %s official instructions for task=%s; using %r.",
instruction_type,
task_name,
fallback,
)
return fallback
return str(np.random.choice(options))
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
DEFAULT_CAMERA_H = 240
DEFAULT_CAMERA_W = 320
@@ -234,6 +342,7 @@ class RoboTwinEnv(gym.Env):
observation_width: int | None = None,
episode_length: int = DEFAULT_EPISODE_LENGTH,
render_mode: str = "rgb_array",
action_mode: str = "joint",
):
super().__init__()
self.task_name = task_name
@@ -241,6 +350,13 @@ class RoboTwinEnv(gym.Env):
self.task_description = task_name.replace("_", " ")
self.episode_index = episode_index
self._reset_stride = n_envs
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
if action_mode not in ("joint", "ee"):
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
self.action_mode = action_mode
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
self._init_eef_pose: np.ndarray | None = None
self.camera_names = list(camera_names)
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
# The YAML-driven lookup is deferred to reset() so construction doesn't
@@ -271,7 +387,7 @@ class RoboTwinEnv(gym.Env):
}
)
self.action_space = spaces.Box(
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
)
def _ensure_env(self) -> None:
@@ -317,6 +433,18 @@ class RoboTwinEnv(gym.Env):
return {"pixels": images, "agent_pos": joint_state}
def _read_eef_pose(self) -> np.ndarray:
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
assert self._env is not None, "_read_eef_pose called before _ensure_env()"
ep = self._env.get_obs()["endpose"]
pose = (
list(ep["left_endpose"])
+ [ep["left_gripper"]]
+ list(ep["right_endpose"])
+ [ep["right_gripper"]]
)
return np.asarray(pose, dtype=np.float64)
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
self._ensure_env()
super().reset(seed=seed)
@@ -330,16 +458,32 @@ class RoboTwinEnv(gym.Env):
self.episode_index += self._reset_stride
self._step_count = 0
use_official_instruction = self.task_name in {"blocks_ranking_rgb", "blocks_ranking_size"}
if _env_flag(OFFICIAL_INSTRUCTION_ENV, default=use_official_instruction):
self.task_description = _generate_robotwin_official_instruction(self.task_name, self._env)
if hasattr(self._env, "set_instruction"):
self._env.set_instruction(instruction=self.task_description)
logger.info("RoboTwin official instruction | task=%s | %s", self.task_name, self.task_description)
else:
self.task_description = self.task_name.replace("_", " ")
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
if self.action_mode == "ee":
self._init_eef_pose = self._read_eef_pose()
obs = self._get_obs()
return obs, {"is_success": False, "task": self.task_name}
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
assert self._env is not None, "step() called before reset()"
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
if action.ndim != 1 or action.shape[0] != self._action_dim:
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
with torch.enable_grad():
if hasattr(self._env, "take_action"):
if self.action_mode == "ee":
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
self._env.take_action(ee_action, action_type="ee")
elif hasattr(self._env, "take_action"):
self._env.take_action(action)
else:
self._env.step(action)
@@ -398,6 +542,7 @@ def _make_env_fns(
observation_height: int,
observation_width: int,
episode_length: int,
action_mode: str = "joint",
) -> list[Callable[[], RoboTwinEnv]]:
"""Return n_envs factory callables for a single task."""
@@ -410,6 +555,7 @@ def _make_env_fns(
observation_height=observation_height,
observation_width=observation_width,
episode_length=episode_length,
action_mode=action_mode,
)
return [partial(_make_one, i) for i in range(n_envs)]
@@ -423,6 +569,7 @@ def create_robotwin_envs(
observation_height: int = DEFAULT_CAMERA_H,
observation_width: int = DEFAULT_CAMERA_W,
episode_length: int = DEFAULT_EPISODE_LENGTH,
action_mode: str = "joint",
) -> dict[str, dict[int, Any]]:
"""Create vectorized RoboTwin 2.0 environments.
@@ -473,6 +620,7 @@ def create_robotwin_envs(
observation_height=observation_height,
observation_width=observation_width,
episode_length=episode_length,
action_mode=action_mode,
)
if is_async:
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
+22
View File
@@ -83,6 +83,28 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("constant_with_warmup")
@dataclass
class ConstantWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Linear warmup followed by a constant learning rate.
Mirrors the ``warmup_constant_lambda`` used by LingBot-VA (upstream ``wan_va/train.py``):
the LR ramps linearly from 0 to the peak over ``num_warmup_steps`` steps, then stays flat.
"""
num_warmup_steps: int = 1000
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
warmup_steps = self.num_warmup_steps or 0
def lr_lambda(current_step):
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
return 1.0
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
+2
View File
@@ -20,6 +20,7 @@ from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig as LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
@@ -44,6 +45,7 @@ __all__ = [
"EO1Config",
"GaussianActorConfig",
"GrootConfig",
"LingBotVAConfig",
"MolmoAct2Config",
"MultiTaskDiTConfig",
"PI0Config",
+15
View File
@@ -49,6 +49,7 @@ from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .lingbot_va.configuration_lingbot_va import LingBotVAConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
elif name == "lingbot_va":
from .lingbot_va.modeling_lingbot_va import LingBotVAPolicy
return LingBotVAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
elif policy_type == "lingbot_va":
return LingBotVAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -448,6 +455,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, LingBotVAConfig):
from .lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
processors = make_lingbot_va_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/lingbot_va.mdx
@@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: ``LingBotVAPolicy`` (and the Wan transformer it owns) imports ``diffusers`` as a
# hard dependency at class-definition time (it subclasses diffusers' ModelMixin/ConfigMixin).
# To keep base ``import lerobot`` working without the optional ``lingbot_va`` extra, the
# policy is exposed lazily via module ``__getattr__`` — the heavy import only happens when
# ``LingBotVAPolicy`` is actually accessed (mirroring the lazy import in policies/factory.py).
from .configuration_lingbot_va import LingBotVAConfig
from .processor_lingbot_va import make_lingbot_va_pre_post_processors
__all__ = ["LingBotVAConfig", "LingBotVAPolicy", "make_lingbot_va_pre_post_processors"]
def __getattr__(name):
if name == "LingBotVAPolicy":
from .modeling_lingbot_va import LingBotVAPolicy
return LingBotVAPolicy
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -0,0 +1,168 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for the LingBot-VA policy.
LingBot-VA is an autoregressive video-action world-model policy built on the Wan2.2
video-diffusion stack. It interleaves prediction of future video latents and robot
actions in a single dual-stream transformer. See ``docs/source/lingbot_va.mdx`` and the
upstream repository (https://github.com/Robbyant/lingbot-va).
Defaults below match the upstream LIBERO configuration (``wan_va/configs/va_libero_cfg.py``)
and the ``transformer/config.json`` of the released checkpoints.
"""
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION
@PreTrainedConfig.register_subclass("lingbot_va")
@dataclass
class LingBotVAConfig(PreTrainedConfig):
"""Configuration for the native LingBot-VA policy integration in LeRobot."""
# Wan transformer architecture
patch_size: tuple[int, int, int] = (1, 2, 2)
num_attention_heads: int = 24
attention_head_dim: int = 128
in_channels: int = 48
out_channels: int = 48
action_dim: int = 30
text_dim: int = 4096
freq_dim: int = 256
ffn_dim: int = 14336
num_layers: int = 30
cross_attn_norm: bool = True
eps: float = 1e-6
rope_max_seq_len: int = 1024
# "flex" = training only (needs recent torch); inference uses "torch" SDPA or "flashattn".
attn_mode: str = "torch"
# Frozen sub-models (VAE + UMT5 text encoder + tokenizer)
# ~20 GB of frozen weights, NOT bundled in the checkpoint; lazily pulled from this HF repo /
# local dir (must hold diffusers-style ``vae/``, ``text_encoder/``, ``tokenizer/`` sub-folders).
wan_pretrained_path: str = "robbyant/lingbot-va-base"
dtype: str = "bfloat16" # transformer / VAE / text-encoder dtype: "bfloat16", "float16", "float32"
# Frozen UMT5-XXL encoder device; "cpu" frees ~11 GB VRAM (it runs once per episode).
text_encoder_device: str = "cpu"
# Observation cameras (order matters: latents are concatenated on width; LIBERO defaults)
obs_cam_keys: list[str] = field(
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
)
# Undo the LIBERO env processor's extra horizontal flip to match the model's training orientation.
image_hflip: bool = False
# Camera latent layout: "width_concat" (cameras concatenated on width; LIBERO) or
# "robotwin_tshape" (full-res head + half-res wrists in a "T"; RoboTwin).
camera_layout: str = "width_concat"
# Inference hyperparameters (LIBERO defaults)
n_obs_steps: int = 1
height: int = 128
width: int = 128
action_per_frame: int = 4
frame_chunk_size: int = 4
attn_window: int = 30
num_inference_steps: int = 20
video_exec_step: int = -1
action_num_inference_steps: int = 50
guidance_scale: float = 5.0
action_guidance_scale: float = 1.0
snr_shift: float = 5.0
action_snr_shift: float = 0.05
max_sequence_length: int = 512 # UMT5 prompt length
# Subset of the 30-d action space used by the benchmark (LIBERO = 7-DoF). The action
# (un)normalization quantiles live in the checkpoint's ``policy_postprocessor.json``, not here.
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
# Opt-in: VAE-decode predicted video latents to ``self.last_predicted_frames`` for saving MP4s.
save_predicted_video: bool = False
# Normalization: IDENTITY here; images are scaled + VAE-encoded and actions are
# quantile-(un)normalized inside the policy / dedicated processor steps.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# Optimizer / scheduler (training; AdamW + warmup-constant per upstream train.py)
optimizer_lr: float = 1e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-4
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 1000
def __post_init__(self):
super().__post_init__()
if self.attn_mode not in ("torch", "flashattn", "flex"):
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
@property
def chunk_size(self) -> int:
"""Number of single-step actions produced per autoregressive chunk."""
return self.frame_chunk_size * self.action_per_frame
@property
def n_action_steps(self) -> int:
"""Number of actions executed before refilling (the whole chunk)."""
return self.chunk_size
def validate_features(self) -> None:
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"LingBot-VA requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if ACTION not in self.output_features:
self.output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION, shape=(len(self.used_action_channel_ids),)
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
# Upstream uses a linear warmup followed by a constant LR (warmup_constant_lambda).
from lerobot.optim.schedulers import ConstantWithWarmupSchedulerConfig
return ConstantWithWarmupSchedulerConfig(num_warmup_steps=self.scheduler_warmup_steps)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,87 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre/post-processor pipelines for the LingBot-VA policy.
The preprocessor passes inputs through (IDENTITY) and the postprocessor maps the policy's
``[-1, 1]`` actions back to physical units with the built-in ``UnnormalizerProcessorStep``
(QUANTILES) using per-channel q01/q99 restored from the checkpoint.
"""
from typing import Any
import torch
from lerobot.configs.types import FeatureType, NormalizationMode
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_lingbot_va import LingBotVAConfig
def make_lingbot_va_pre_post_processors(
config: LingBotVAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build the pre/post processor pipelines for LingBot-VA."""
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
# Unnormalize actions from [-1, 1] to physical units (QUANTILES) using q01/q99 restored from the checkpoint.
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map={FeatureType.ACTION: NormalizationMode.QUANTILES},
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
+54 -278
View File
@@ -32,6 +32,7 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -280,11 +281,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -342,108 +338,30 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
config_filename = kwargs.pop("config_filename", None)
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -451,110 +369,31 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
step_entry["config"] = processor_step.get_config()
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
pipeline_config["steps"].append(step_entry)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
return pipeline_config
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
config["steps"].append(step_entry)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -738,54 +577,12 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
pipeline = cls(
return cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -869,7 +666,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -889,7 +688,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config value to validate (may be non-dict)
loaded_config: The loaded config dictionary (guaranteed non-None)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -903,14 +702,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
)
@classmethod
@@ -972,41 +766,26 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
return steps, remaining_override_keys
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
processor_steps.append(processor_step)
# 3. Load step state if available
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
return processor_steps, remaining_override_keys
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
steps.append(step_instance)
return steps, override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1317,7 +1096,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: Any) -> bool:
def _is_processor_config(cls, config: dict) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1368,9 +1147,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
+90 -11
View File
@@ -105,6 +105,7 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
predicted_latents_callback: Callable[[PreTrainedPolicy], None] | None = None,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -134,6 +135,9 @@ def rollout(
are returned optionally because they typically take more memory to cache. Defaults to False.
render_callback: Optional rendering callback to be used after the environments are reset, and after
every step.
predicted_latents_callback: Optional callback invoked after every ``select_action`` with the policy
itself. World-model policies (e.g. LingBot-VA) stash predicted video latents on
``policy.last_predicted_latents``; this lets the caller concatenate chunks and decode once.
Returns:
The dictionary described above.
"""
@@ -184,6 +188,8 @@ def rollout(
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
if predicted_latents_callback is not None:
predicted_latents_callback(policy)
action = postprocessor(action)
action_transition = {ACTION: action}
@@ -203,12 +209,22 @@ def rollout(
# available if none of the envs finished.
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
if isinstance(final_info, dict):
is_success = final_info.get("is_success", [False] * env.num_envs)
successes = (
is_success.tolist()
if hasattr(is_success, "tolist")
else [bool(is_success)] * env.num_envs
)
successes = final_info["is_success"].tolist()
else:
# Gymnasium < 1.0 returns final_info as a per-env sequence/object array,
# with entries set to a dict only for envs that just finished.
successes = []
for item in final_info:
if isinstance(item, dict) and "is_success" in item:
successes.append(bool(item["is_success"]))
else:
successes.append(False)
elif "is_success" in info:
is_success = info["is_success"]
successes = (
@@ -273,6 +289,7 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
save_predicted_video: bool = False,
) -> dict:
"""
Args:
@@ -291,6 +308,11 @@ def eval_policy(
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
# World-model policies (e.g. LingBot-VA) opt into predicted-video saving via their config.
save_predicted_video = save_predicted_video or bool(
getattr(getattr(policy, "config", None), "save_predicted_video", False)
)
if not isinstance(policy, PreTrainedPolicy):
exc = ValueError(
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
@@ -334,6 +356,22 @@ def eval_policy(
if max_episodes_rendered > 0:
video_paths: list[str] = []
if save_predicted_video:
if not videos_dir:
raise ValueError("If save_predicted_video is True, videos_dir must be provided.")
predicted_video_paths: list[str] = []
n_predicted_rendered = 0
# Collect predicted-video latents across a rollout (world-model policies only). The latents are
# concatenated and decoded once after the rollout, matching upstream LingBot-VA's visualization path.
def collect_predicted_latents(policy: PreTrainedPolicy):
latents = getattr(policy, "last_predicted_latents", None)
if latents is not None:
pred_latents.append(
latents.detach().to("cpu") if hasattr(latents, "detach") else torch.as_tensor(latents).cpu()
)
policy.last_predicted_latents = None
if return_episode_data:
episode_data: dict | None = None
@@ -345,6 +383,9 @@ def eval_policy(
if max_episodes_rendered > 0:
ep_frames: list[np.ndarray] = []
if save_predicted_video:
pred_latents: list[torch.Tensor] = []
if start_seed is None:
seeds = None
else:
@@ -361,6 +402,7 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
predicted_latents_callback=collect_predicted_latents if save_predicted_video else None,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -426,6 +468,35 @@ def eval_policy(
threads.append(thread)
n_episodes_rendered += 1
# Maybe save the policy's predicted (imagined) video for this batch's rollout.
if save_predicted_video and len(pred_latents) > 0:
predicted_latent = torch.cat(pred_latents, dim=2)
decoder = getattr(policy, "decode_predicted_latents", None) or getattr(
policy, "_decode_predicted_video", None
)
if decoder is None:
raise AttributeError(
"Policy config requested predicted-video saving, but the policy does not expose "
"`decode_predicted_latents` or `_decode_predicted_video`."
)
predicted_video = decoder(predicted_latent)
if hasattr(predicted_video, "detach"):
predicted_video = predicted_video.detach().to("cpu").numpy()
videos_dir.mkdir(parents=True, exist_ok=True)
predicted_video_path = videos_dir / f"pred_episode_{n_predicted_rendered}.mp4"
predicted_video_paths.append(str(predicted_video_path))
thread = threading.Thread(
target=write_video,
args=(
str(predicted_video_path),
predicted_video,
env.unwrapped.metadata["render_fps"],
),
)
thread.start()
threads.append(thread)
n_predicted_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
)
@@ -469,6 +540,9 @@ def eval_policy(
if max_episodes_rendered > 0:
info["video_paths"] = video_paths
if save_predicted_video:
info["predicted_video_paths"] = predicted_video_paths
return info
@@ -600,9 +674,10 @@ class TaskMetrics(TypedDict):
max_rewards: list[float]
successes: list[bool]
video_paths: list[str]
predicted_video_paths: list[str]
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths", "predicted_video_paths")
def eval_one(
@@ -643,6 +718,7 @@ def eval_one(
max_rewards=[ep["max_reward"] for ep in per_episode],
successes=[ep["success"] for ep in per_episode],
video_paths=task_result.get("video_paths", []),
predicted_video_paths=task_result.get("predicted_video_paths", []),
)
@@ -689,6 +765,7 @@ def run_one(
# ensure we always provide video_paths key to simplify accumulation
if max_episodes_rendered > 0:
metrics.setdefault("video_paths", [])
metrics.setdefault("predicted_video_paths", [])
return task_group, task_id, metrics
@@ -742,11 +819,11 @@ def eval_policy_all(
_append("sum_rewards", metrics.get("sum_rewards"))
_append("max_rewards", metrics.get("max_rewards"))
_append("successes", metrics.get("successes"))
# video_paths is list-like
paths = metrics.get("video_paths", [])
if paths:
group_acc[group]["video_paths"].extend(paths)
overall["video_paths"].extend(paths)
for key in ("video_paths", "predicted_video_paths"):
paths = metrics.get(key, [])
if paths:
group_acc[group][key].extend(paths)
overall[key].extend(paths)
# Choose runner (sequential vs threaded)
task_runner = partial(
@@ -814,6 +891,7 @@ def eval_policy_all(
"pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
"n_episodes": len(acc["sum_rewards"]),
"video_paths": list(acc["video_paths"]),
"predicted_video_paths": list(acc["predicted_video_paths"]),
}
# overall aggregates
@@ -825,6 +903,7 @@ def eval_policy_all(
"eval_s": time.time() - start_t,
"eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])),
"video_paths": list(overall["video_paths"]),
"predicted_video_paths": list(overall["predicted_video_paths"]),
}
return {
+9 -29
View File
@@ -232,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -387,21 +384,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames") and not cfg.dataset.streaming:
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
@@ -426,16 +416,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Prepare everything with accelerator
accelerator.wait_for_everyone()
if cfg.dataset.streaming:
# The streaming IterableDataset is already rank-disjoint via split_dataset_by_node, so we must
# NOT hand the dataloader to accelerate: its IterableDatasetShard would keep only every
# world_size-th batch of each rank's already-disjoint stream (silently training on 1/N of the
# data while decoding all of it). Batches are moved to the device manually in the loop below.
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
else:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
policy.train()
@@ -475,9 +458,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
if cfg.dataset.streaming:
# The streaming dataloader is not accelerate-prepared (see above), so move to device here.
batch = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}
for cam_key in dataset.meta.camera_keys:
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
-150
View File
@@ -1,150 +0,0 @@
"""Acceptance tests for manifest byte-index sidecars.
Run on a compute node (not login-node):
srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\
bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\
env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v'
"""
from __future__ import annotations
import json
import socket
import pytest
pytest.importorskip("torchcodec")
REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
MAX_EPISODES = 64
COMPUTE_NODE = pytest.mark.skipif(
"login" in socket.gethostname(),
reason="run on compute node via srun (see module docstring), not login-node",
)
@pytest.fixture(scope="module")
def byte_index_dir(tmp_path_factory):
from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
out = tmp_path_factory.mktemp("byte_index")
meta = LeRobotDatasetMetadata(REPO, revision=REV)
files, episodes, _ = build_byte_index_tables(
meta, BUCKET, workers=4, max_episodes=MAX_EPISODES, include_keyframes=False
)
write_byte_index(out, files, episodes, None, merge_existing=False)
return out, meta
@pytest.mark.integration
@COMPUTE_NODE
def test_index_load_fast_and_small(byte_index_dir):
from lerobot.datasets.byte_index import EpisodeByteIndex
out, meta = byte_index_dir
index = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
assert index.load_time_s < 1.0
assert index.resident_bytes < 1_000_000_000
@pytest.mark.integration
@COMPUTE_NODE
def test_tight_fetch_under_25mb(byte_index_dir):
from lerobot.datasets.byte_index import EpisodeByteIndex
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
cache.submit_prefetch(ep)
cache.ensure_ready(ep)
stats = cache.stats.stats_dict()
assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024
@pytest.mark.integration
@COMPUTE_NODE
def test_in_memory_build_matches_parquet(byte_index_dir):
from lerobot.datasets.byte_index import EpisodeByteIndex
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
out, meta = byte_index_dir
disk = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
mem = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
for cam in meta.video_keys:
a = disk.lookup(ep, cam)
b = mem.lookup(ep, cam)
assert a.mdat_offset == b.mdat_offset
assert a.mdat_length == b.mdat_length
assert abs(a.first_pts - b.first_pts) < 1e-6
@pytest.mark.integration
@COMPUTE_NODE
def test_custom_frame_mappings_available(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cam = meta.video_keys[0]
ep = MAX_EPISODES // 2
payload = index.custom_frame_mappings(ep, cam)
assert payload is not None
data = json.loads(payload)
assert len(data["frames"]) > 10
assert any(f["key_frame"] for f in data["frames"])
assert all("pts" in f and "duration" in f for f in data["frames"])
@pytest.mark.integration
@COMPUTE_NODE
def test_metadata_skip_decoder_init(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET)
cam = meta.video_keys[0]
ep = 0
cache.submit_prefetch(ep)
cache.ensure_ready(ep)
dec = cache.get_decoder(ep, cam)
assert dec.metadata.num_frames is not None
assert dec.metadata.num_frames > 0
begin = float(dec.metadata.begin_stream_seconds)
end = float(dec.metadata.end_stream_seconds)
ts = begin + 0.5 * (end - begin)
frame = dec.get_frames_played_at([ts]).data
assert frame.ndim == 4
@pytest.mark.integration
@COMPUTE_NODE
def test_sparse_decode_produces_frames(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
cam = meta.video_keys[0]
ep = 0
cache.submit_prefetch(ep)
cache.ensure_ready(ep)
dec = cache.get_decoder(ep, cam)
begin = float(dec.metadata.begin_stream_seconds)
end = float(dec.metadata.end_stream_seconds)
ts = begin + 0.5 * (end - begin)
frame = dec.get_frames_played_at([ts]).data
assert frame.ndim == 4
assert frame.numel() > 0
assert float(frame.float().std()) > 1.0
-24
View File
@@ -114,30 +114,6 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+95 -30
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
@@ -24,6 +25,52 @@ from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
rng = np.random.default_rng(streaming_ds.seed)
buffer_size = streaming_ds.buffer_size
num_shards = streaming_ds.num_shards
shards_indices = []
for shard_idx in range(num_shards):
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
shard_indices = [item["index"] for item in shard]
shards_indices.append(shard_indices)
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
frames_buffer = []
expected_indices = []
while shard_iterators: # While there are still available shards
available_shard_keys = list(shard_iterators.keys())
if not available_shard_keys:
break
# Call _infinite_generator_over_elements with current available shards (key difference!)
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
try:
frame_index = next(shard_iterators[shard_key])
if len(frames_buffer) == buffer_size:
i = next(buffer_indices_generator)
expected_indices.append(frames_buffer[i])
frames_buffer[i] = frame_index
else:
frames_buffer.append(frame_index)
except StopIteration:
del shard_iterators[shard_key] # Remove exhausted shard
rng.shuffle(frames_buffer)
expected_indices.extend(frames_buffer)
return expected_indices
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
"""Test if are correctly accessed"""
ds_num_frames = 400
@@ -73,9 +120,10 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
[False, True],
)
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
ds_num_frames = 400
ds_num_episodes = 10
buffer_size = 100
seed = 42
n_epochs = 3
@@ -90,17 +138,25 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
)
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
for epoch_indices in epochs:
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
if shuffle:
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
else:
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
expected_indices = get_frames_expected_order(streaming_ds)
for _ in range(n_epochs):
streaming_indices = [frame["index"] for frame in streaming_ds]
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
)
if shuffle:
assert not frames_match
else:
assert frames_match
@pytest.mark.parametrize(
@@ -108,11 +164,15 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
[False, True],
)
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
ds_num_frames = 100
ds_num_episodes = 10
buffer_size = 10
seed = 42
n_epochs = 3
data_file_size_mb = 0.001
chunks_size = 1
local_path = tmp_path / "test"
@@ -127,21 +187,31 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
chunks_size=chunks_size,
)
def make_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
episode_pool_size=3,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
buffer_size=buffer_size,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
for _ in range(n_epochs):
streaming_indices = [
frame["index"] for frame in streaming_ds
] # NOTE: this is the same as first_epoch_indices
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
)
first = [int(frame["index"]) for frame in make_ds()]
again = [int(frame["index"]) for frame in make_ds()]
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
assert first == again, "same seed must reproduce the same order"
if shuffle:
assert not frames_match
else:
assert frames_match
@pytest.mark.parametrize(
@@ -218,11 +288,6 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
check = torch.allclose(left, right) and left.shape == right.shape
else:
# Scalar numerics: streaming yields python floats/ints where map-style yields
# 0-dim tensors (long-standing accepted difference). Compare by value.
check = float(left) == float(right)
key_checks.append((key, check))
assert all(t[1] for t in key_checks), (
@@ -1,100 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
"""
import json
import shutil
import subprocess
import sys
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
from tests.fixtures.constants import DUMMY_REPO_ID
WORKER = """
import json, sys
from accelerate import PartialState
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
json.dump(payload, f)
"""
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
total_frames = 160
repo_id = f"{DUMMY_REPO_ID}-acc"
root = tmp_path / "ds"
lerobot_dataset_factory(
root=root,
repo_id=repo_id,
total_episodes=8,
total_frames=total_frames,
use_videos=False,
data_files_size_in_mb=0.001,
chunks_size=1,
)
worker = tmp_path / "worker.py"
worker.write_text(WORKER)
out_dir = tmp_path / "out"
out_dir.mkdir()
cmd = [
"accelerate",
"launch",
"--num_processes=2",
"--num_machines=1",
"--mixed_precision=no",
"--dynamo_backend=no",
"--cpu",
str(worker),
str(root),
repo_id,
str(out_dir),
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
assert result.returncode == 0, (
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
)
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
rank_sets = [set(p["indices"]) for p in payloads]
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-v"]))
-430
View File
@@ -1,430 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
prefetching, deterministic fast-forward resume, and schema parity."""
import pytest
import torch
from torch.utils.data import DataLoader
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
factory(
root=root,
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
use_videos=use_videos,
data_files_size_in_mb=0.001,
chunks_size=1,
**kw,
)
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
return [int(frame["index"]) for frame in ds]
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
monkeypatch.delenv("RANK", raising=False)
monkeypatch.delenv("WORLD_SIZE", raising=False)
# No accelerate state, no env -> single process.
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
monkeypatch.setenv("RANK", "3")
monkeypatch.setenv("WORLD_SIZE", "4")
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
repo_id = f"{DUMMY_REPO_ID}-ranks"
total_frames, total_episodes = 200, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
world_size = 2
per_rank = []
for rank in range(world_size):
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
episode_pool_size=8,
max_num_shards=8,
rank=rank,
world_size=world_size,
)
per_rank.append(set(_stream_indices(ds)))
assert per_rank[0].isdisjoint(per_rank[1]), (
"ranks streamed overlapping frames (duplicate data across GPUs)"
)
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
repo_id = f"{DUMMY_REPO_ID}-workers"
total_frames, total_episodes = 120, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
)
loader = DataLoader(ds, batch_size=None, num_workers=2)
indices = [int(batch["index"]) for batch in loader]
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
"""
repo_id = f"{DUMMY_REPO_ID}-sarm"
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
episode_frames = 300
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
)
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
delta_timestamps = {ACTION: [0.0, horizon_s]}
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
episode_pool_size=1,
max_num_shards=1,
delta_timestamps=delta_timestamps,
)
horizon_frames = int(round(horizon_s * ds.fps))
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
checked = 0
for frame in ds:
idx = int(frame["index"])
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
if idx + horizon_frames < episode_frames:
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
)
checked += 1
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
repo_id = f"{DUMMY_REPO_ID}-seeds"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
def order(seed):
return _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=seed,
episode_pool_size=4,
max_num_shards=2,
)
)
assert order(0) == order(0), "same seed must reproduce the same order"
assert order(0) != order(1), "different seeds should give different orders"
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
repo_id = f"{DUMMY_REPO_ID}-epochs"
total_frames = 120
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
)
epoch_0 = _stream_indices(ds)
epoch_1 = _stream_indices(ds)
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
assert epoch_0 != epoch_1, "epoch did not reshuffle"
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
repo_id = f"{DUMMY_REPO_ID}-mix"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
)
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
repo_id = f"{DUMMY_REPO_ID}-parity"
map_ds = lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
)
stream_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
)
map_frame = map_ds[0]
stream_frame = next(iter(stream_ds))
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
for key, value in stream_frame.items():
ref = map_frame[key]
if isinstance(value, torch.Tensor):
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
)
elif isinstance(value, str):
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
else:
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
# (a long-standing, accepted difference). Compare by value rather than exact type.
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
import lerobot.datasets.streaming_dataset as sd
repo_id = f"{DUMMY_REPO_ID}-vpath"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
)
seen_paths = []
def fake_decode(video_path, query_ts, *args, **kwargs):
seen_paths.append(str(video_path))
return torch.zeros(len(query_ts), 3, 64, 96)
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
next(iter(ds))
assert seen_paths, "no video decode was issued"
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
repo_id = f"{DUMMY_REPO_ID}-shuf"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ordered = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
)
)
shuffled = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
)
)
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
assert shuffled != ordered, "shuffle did not decorrelate output order"
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
"""Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
repo_id = f"{DUMMY_REPO_ID}-native-resume"
total_frames = 100
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
)
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=7,
episode_pool_size=2,
frame_shuffle_buffer_size=8,
)
ds = fresh_ds()
it = iter(ds)
consumed = [int(next(it)["index"]) for _ in range(30)]
state = ds.state_dict()
resumed_ds = fresh_ds()
resumed_ds.load_state_dict(state)
rest = [int(frame["index"]) for frame in resumed_ds]
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
# in-flight buffer contents are skipped on resume (documented datasets behavior):
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
covered = len(set(consumed) | set(rest))
max_in_flight = 2 * 30 + 8
assert covered >= total_frames - max_in_flight
assert covered + len(consumed) >= total_frames - max_in_flight
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
import datasets as hf_datasets
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
assert state is not None
# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
"""With one row group per episode (the writer's invariant), reshard() turns each episode into its
own shard, so num_shards == total_episodes even when many episodes share a single data file."""
import pyarrow.parquet as pq
repo_id = f"{DUMMY_REPO_ID}-reshard"
total_episodes = 3
# Default (large) data-file size packs all (unequal-length) episodes into one file, so the only way
# num_shards can reach total_episodes is per-row-group resharding.
lerobot_dataset_factory(
root=tmp_path / "ds",
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=90,
use_videos=False,
)
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
file_to_eps = ds._episode_files()
assert len(file_to_eps) == 1, "test expects all episodes packed into a single data file"
for (chunk_idx, file_idx), eps in file_to_eps.items():
rel = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
assert pq.ParquetFile(str(ds.root / rel)).num_row_groups == len(eps)
assert ds.num_shards == total_episodes
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
"""max_buffer_input_shards (== concurrently-live random episodes) drives the per-batch episode mix:
a single batch should already span most of the live episodes."""
repo_id = f"{DUMMY_REPO_ID}-frac"
total_episodes = 8
lerobot_dataset_factory(
root=tmp_path / "ds",
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=240,
use_videos=False,
)
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=0,
episode_pool_size=total_episodes,
max_buffer_input_shards=total_episodes,
)
assert ds.max_buffer_input_shards == total_episodes
batch = 32
head = {int(frame["episode_index"]) for _, frame in zip(range(batch), ds, strict=False)}
assert len(head) >= min(total_episodes, batch) - 2, f"batch did not mix random episodes: {head}"
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
"""A data file that collapses several episodes into a single row group (bulk df.to_parquet /
push_to_hub) must be rejected with an actionable error: reshard() cannot address its episodes."""
import pyarrow.parquet as pq
repo_id = f"{DUMMY_REPO_ID}-collapsed"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
)
# Rewrite every data file as a single row group (simulating the aggregate/push_to_hub collapse).
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
pq.write_table(pq.read_table(parquet_path), parquet_path)
with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
"""validate_row_groups=False skips the row-group check (collapsed datasets still load, degraded)."""
import pyarrow.parquet as pq
repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
)
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
pq.write_table(pq.read_table(parquet_path), parquet_path)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, validate_row_groups=False
)
assert sorted(int(frame["index"]) for frame in ds) == list(range(90))
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
"""When num_shards (== episodes after reshard) is not divisible by world_size, every rank would
stream the whole dataset; the guard must raise instead of silently degrading."""
repo_id = f"{DUMMY_REPO_ID}-divis"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
)
with pytest.raises(ValueError, match="not divisible by world_size"):
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, rank=0, world_size=2
)
# Bypassing the guard downgrades it to a warning (no raise).
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
episode_pool_size=3,
rank=0,
world_size=2,
validate_row_groups=False,
)
assert ds.num_shards == 3
+3 -22
View File
@@ -17,7 +17,6 @@ from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
@@ -36,24 +35,6 @@ from lerobot.datasets.utils import (
)
def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
"""Write ``hf_dataset`` to ``path`` with one Parquet row group per episode.
Mirrors the LeRobot recording writer (one ``write_table`` per episode) so each episode stays an
independently addressable shard after ``datasets.IterableDataset.reshard()``, which
``StreamingLeRobotDataset`` relies on. ``Dataset.to_parquet`` would collapse the file into a
single row group instead.
"""
table = hf_dataset.with_format("arrow")[:]
episode_index = np.asarray(hf_dataset["episode_index"])
boundaries = np.where(np.diff(episode_index) != 0)[0] + 1
starts = [0, *boundaries.tolist()]
ends = [*boundaries.tolist(), len(episode_index)]
with pq.ParquetWriter(str(path), table.schema) as writer:
for start, end in zip(starts, ends, strict=True):
writer.write_table(table.slice(start, end - start))
def write_hf_dataset(
hf_dataset: Dataset,
local_dir: Path,
@@ -86,7 +67,7 @@ def write_hf_dataset(
# If the dataset is small enough, write it to a single file.
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
path.parent.mkdir(parents=True, exist_ok=True)
_to_parquet_one_row_group_per_episode(hf_dataset, path)
hf_dataset.to_parquet(path)
return
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
@@ -133,8 +114,8 @@ def write_hf_dataset(
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
# Write the shard to a Parquet file (one row group per episode).
_to_parquet_one_row_group_per_episode(dataset_shard, path)
# Write the shard to a Parquet file.
dataset_shard.to_parquet(path)
# Update chunk and file indices for the next iteration.
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES
def make_config(**overrides) -> LingBotVAConfig:
kwargs = {"device": "cpu"}
kwargs.update(overrides)
return LingBotVAConfig(**kwargs)
def test_registered_in_choice_registry() -> None:
assert "lingbot_va" in PreTrainedConfig.get_known_choices()
assert PreTrainedConfig.get_choice_class("lingbot_va") is LingBotVAConfig
def test_type_property() -> None:
assert make_config().type == "lingbot_va"
def test_chunk_size_and_action_steps() -> None:
cfg = make_config(frame_chunk_size=4, action_per_frame=4)
assert cfg.chunk_size == 16
assert cfg.n_action_steps == 16
assert cfg.action_delta_indices == list(range(16))
assert cfg.observation_delta_indices is None
assert cfg.reward_delta_indices is None
def test_optimizer_and_scheduler_presets() -> None:
cfg = make_config()
opt = cfg.get_optimizer_preset()
assert opt.lr == cfg.optimizer_lr
sched = cfg.get_scheduler_preset()
assert sched.num_warmup_steps == cfg.scheduler_warmup_steps
def test_validate_features_sets_action_feature() -> None:
cfg = make_config()
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
cfg.output_features = {}
cfg.validate_features()
assert ACTION in cfg.output_features
assert cfg.output_features[ACTION].shape == (len(cfg.used_action_channel_ids),)
def test_validate_features_no_visual_raises() -> None:
cfg = make_config()
cfg.input_features = {}
cfg.output_features = {}
with pytest.raises(ValueError, match="at least one visual input feature"):
cfg.validate_features()
def test_invalid_attn_mode_raises() -> None:
with pytest.raises(ValueError, match="attn_mode"):
make_config(attn_mode="banana")
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import pytest
from lerobot.policies.factory import make_policy_config
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
def test_make_policy_config_returns_lingbot_va() -> None:
cfg = make_policy_config("lingbot_va", device="cpu")
assert isinstance(cfg, LingBotVAConfig)
def test_get_policy_class_resolves_lazily() -> None:
# Importing the policy class pulls in diffusers (Wan2.2 stack); skip if unavailable.
pytest.importorskip("diffusers")
pytest.importorskip("transformers")
from lerobot.policies.factory import get_policy_class
cls = get_policy_class("lingbot_va")
assert cls.name == "lingbot_va"
assert cls.config_class is LingBotVAConfig
+131
View File
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
from lerobot.policies.lingbot_va.modeling_lingbot_va import ( # noqa: E402
FlowMatchScheduler,
data_seq_to_patch,
get_mesh_id,
)
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
assert sch.timesteps.shape == (20,)
diffs = sch.timesteps[1:] - sch.timesteps[:-1]
assert torch.all(diffs <= 0) # decreasing
def test_flow_match_scheduler_step_preserves_shape() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
sample = torch.zeros(1, 48, 4, 8, 16)
out = sch.step(torch.ones_like(sample), sch.timesteps[0], sample)
assert out.shape == sample.shape
def test_flow_match_scheduler_add_noise() -> None:
sch = FlowMatchScheduler(shift=5.0, sigma_min=0.0, extra_one_step=True)
sch.set_timesteps(20)
sample = torch.randn(1, 48, 4, 8, 16)
noise = torch.randn_like(sample)
noisy = sch.add_noise(sample, noise, sch.timesteps[:4], t_dim=2)
assert noisy.shape == sample.shape
def test_get_mesh_id_latent_shape() -> None:
grid = get_mesh_id(4, 8, 16, 0, 1, 0)
assert grid.shape == (4, 4 * 8 * 16) # (f, h, w, stream) x tokens
def test_get_mesh_id_action_shape() -> None:
grid = get_mesh_id(4, 4, 1, 1, 1, 0, action=True)
assert grid.shape == (4, 4 * 4 * 1)
# Action rows for h/w are sentinel -1.
assert torch.all(grid[1] < 0)
assert torch.all(grid[2] < 0)
def test_data_seq_to_patch_roundtrip_shape() -> None:
b, f, h, w, c = 1, 4, 8, 16, 48
seq = torch.arange(b * f * h * w * c, dtype=torch.float32).reshape(b, f * h * w, c)
out = data_seq_to_patch((1, 2, 2), seq, f, h, w, batch_size=b)
assert out.shape == (b, c, f, h, w)
def test_training_step_reduces_loss_tiny_flex() -> None:
"""End-to-end single training step (flow-matching loss -> backward -> AdamW) on a tiny config.
Exercises the flex-attention training path; requires a CUDA GPU with flex-attention support.
"""
if not torch.cuda.is_available():
import pytest
pytest.skip("training step test requires a CUDA GPU (flex-attention)")
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES
cfg = LingBotVAConfig(
attn_mode="flex",
dtype="bfloat16",
in_channels=16,
out_channels=16,
action_dim=8,
text_dim=32,
freq_dim=64,
ffn_dim=64,
num_attention_heads=2,
attention_head_dim=24,
num_layers=2,
frame_chunk_size=2,
action_per_frame=4,
used_action_channel_ids=[0, 1, 2, 3],
obs_cam_keys=[f"{OBS_IMAGES}.image"],
device="cuda",
)
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64))}
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,))}
cfg.validate_features()
policy = LingBotVAPolicy(cfg).to("cuda")
policy.train()
opt = torch.optim.AdamW(policy.get_optim_params(), lr=1e-4)
b, fc, apf = 1, cfg.frame_chunk_size, cfg.action_per_frame
latents = torch.randn(b, cfg.in_channels, fc, 4, 4, device="cuda", dtype=torch.bfloat16)
actions = torch.randn(b, cfg.action_dim, fc, apf, 1, device="cuda", dtype=torch.bfloat16)
amask = torch.zeros(cfg.action_dim, device="cuda")
amask[cfg.used_action_channel_ids] = 1.0
actions_mask = amask.view(1, -1, 1, 1, 1).expand_as(actions)
text_emb = torch.randn(b, cfg.max_sequence_length, cfg.text_dim, device="cuda", dtype=torch.bfloat16)
loss, metrics = policy.training_loss_from_streams(latents, actions, actions_mask, text_emb)
assert torch.isfinite(loss) and {"latent_loss", "action_loss"} <= set(metrics)
loss.backward()
assert any(p.grad is not None and torch.isfinite(p.grad).all() for p in policy.get_optim_params())
opt.step()
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline, UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def _make_config() -> LingBotVAConfig:
cfg = LingBotVAConfig(device="cpu")
cfg.input_features = {f"{OBS_IMAGES}.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128))}
cfg.output_features = {}
cfg.validate_features()
return cfg
def test_make_pre_post_processors_names_and_steps() -> None:
cfg = _make_config()
pre, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
assert pre.name == POLICY_PREPROCESSOR_DEFAULT_NAME
assert post.name == POLICY_POSTPROCESSOR_DEFAULT_NAME
# Actions are unnormalized by the standard built-in quantile unnormalizer.
assert any(isinstance(s, UnnormalizerProcessorStep) for s in post.steps)
def test_freshly_built_postprocessor_is_identity() -> None:
# Without action stats the quantile unnormalizer is a no-op (identity passthrough): the real
# per-benchmark q01/q99 are restored from the saved checkpoint on load, not hardcoded here.
cfg = _make_config()
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
normed = torch.tensor([[0.3, -0.5, 1.0, -1.0, 0.0, 0.7, -0.2]])
assert torch.allclose(post(normed), normed, atol=1e-6)
def test_postprocessor_quantile_unnormalization() -> None:
# QUANTILES unnormalize maps [-1, 1] -> [q01, q99]: -1 -> q01, +1 -> q99.
cfg = _make_config()
q01 = [-1.0, -0.5, 0.0, -1.0, -1.0, -1.0, -1.0]
q99 = [1.0, 0.5, 2.0, 1.0, 1.0, 1.0, 1.0]
stats = {ACTION: {"q01": q01, "q99": q99}}
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=stats)
out_lo = post(torch.full((1, 7), -1.0))
out_hi = post(torch.full((1, 7), 1.0))
assert torch.allclose(out_lo, torch.tensor(q01).unsqueeze(0), atol=1e-4)
assert torch.allclose(out_hi, torch.tensor(q99).unsqueeze(0), atol=1e-4)
def test_postprocessor_stats_survive_save_load(tmp_path) -> None:
# Regression guard for the Hub mechanism: the q01/q99 stats live in the saved post-processor
# state and must round-trip through save_pretrained / from_pretrained.
cfg = _make_config()
q01 = [-0.6, -0.8, -0.9, -0.1, -0.15, -0.25, -1.0]
q99 = [0.9, 0.85, 0.9, 0.17, 0.18, 0.34, 1.0]
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats={ACTION: {"q01": q01, "q99": q99}})
post.save_pretrained(tmp_path)
loaded = PolicyProcessorPipeline.from_pretrained(
tmp_path,
config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json",
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
out = loaded(torch.full((1, 7), -1.0))
assert torch.allclose(out, torch.tensor(q01).unsqueeze(0), atol=1e-4)
-220
View File
@@ -24,7 +24,6 @@ from typing import Any
import pytest
import torch
import torch.nn as nn
from safetensors.torch import load_file
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
@@ -175,53 +174,6 @@ class MockStepWithTensorState(ProcessorStep):
return features
class MockLazyTensorStateStep(ProcessorStep):
"""Mock step whose tensor state is not present in constructor config."""
def __init__(
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
):
self.name = name
self.scale = scale
self.tensor_state: torch.Tensor | None = None
if initial_value is not None:
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Return the transition unchanged."""
return transition
def get_config(self) -> dict[str, Any]:
"""Return constructor config while intentionally omitting tensor state."""
return {
"name": self.name,
"scale": self.scale,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return tensor state only after it has been initialized or loaded."""
if self.tensor_state is None:
return {}
return {"tensor_state": self.tensor_state}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load tensor state."""
self.tensor_state = state["tensor_state"].clone()
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Return features unchanged."""
return features
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
"""Registered lazy tensor state step for registry-based serialization tests."""
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
@@ -668,178 +620,6 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean)
def test_get_config_matches_saved_json():
"""Test that in-memory config matches the config written by save_pretrained."""
stateless_step = MockStep(name="stateless")
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
in_memory_config = pipeline.get_config()
assert pipeline.get_config() == in_memory_config
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
config_path = Path(tmp_dir) / "memory_pipeline.json"
with open(config_path) as file_pointer:
saved_config = json.load(file_pointer)
assert in_memory_config == saved_config
assert "state_file" not in in_memory_config["steps"][0]
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
def test_state_dict_matches_saved_safetensors():
"""Test that in-memory state matches the safetensors written by save_pretrained."""
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
in_memory_state_dict = pipeline.state_dict()
state_filename = "stateful_pipeline_step_0.safetensors"
state_key = "stateful_pipeline_step_0"
assert set(in_memory_state_dict) == {state_key}
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
in_memory_state_dict[state_key]["tensor_state"].add_(1)
assert stateful_step.tensor_state is not None
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
def test_save_pretrained_still_writes_expected_serialization_files():
"""Test that save_pretrained keeps the existing config and state filenames."""
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
save_path = Path(tmp_dir)
assert (save_path / "policy_preprocessor.json").exists()
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
def test_from_config_round_trips_stateful_pipeline():
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert len(loaded_pipeline) == 1
assert isinstance(loaded_step, MockLazyTensorStateStep)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
def test_from_config_round_trips_registered_stateful_pipeline():
"""Test that from_config resolves registry steps and loads their named tensor state."""
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
assert config["steps"][0]["state_file"] == state_filename
assert set(pipeline_state_dict) == {state_key}
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
assert loaded_step.tensor_state is not None
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
def test_from_config_preserves_state_metadata_for_empty_initial_state():
"""Test in-memory loading when rebuilt steps start without tensor state."""
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.state_dict() == {}
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
loaded_pipeline.load_state_dict(pipeline_state_dict)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
def test_from_config_applies_overrides_before_state_loading():
"""Test that constructor overrides and tensor state loading are separate operations."""
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(
config,
state_dict=pipeline_state_dict,
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.scale == 5.0
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
def test_load_state_dict_raises_on_missing_expected_state():
"""Test loading raises when serialized config expects missing state."""
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
loaded_pipeline.load_state_dict({})
def test_load_state_dict_raises_on_unexpected_extra_state():
"""Test loading raises on unexpected top-level state keys."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
with pytest.raises(KeyError, match="extra"):
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
"""Test stateless in-memory serialization and loading."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
config = pipeline.get_config()
config_without_name = {"steps": config["steps"]}
assert pipeline.state_dict() == {}
assert all("state_file" not in step_entry for step_entry in config["steps"])
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
assert loaded_pipeline.name == "DataProcessorPipeline"
assert loaded_pipeline.state_dict() == {}
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
"""Test from_config reports invalid top-level config values cleanly."""
with pytest.raises(ValueError, match="not a valid processor configuration"):
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
class MockModuleStep(ProcessorStep, nn.Module):
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
Generated
+56 -22
View File
@@ -1084,8 +1084,8 @@ wheels = [
[[package]]
name = "datasets"
version = "5.0.1.dev0"
source = { git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3#2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
version = "4.8.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dill" },
{ name = "filelock" },
@@ -1102,6 +1102,10 @@ dependencies = [
{ name = "tqdm" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
]
[[package]]
name = "debugpy"
@@ -1168,10 +1172,11 @@ wheels = [
[[package]]
name = "diffusers"
version = "0.35.2"
version = "0.36.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "httpx" },
{ name = "huggingface-hub" },
{ name = "importlib-metadata" },
{ name = "numpy" },
@@ -1180,9 +1185,9 @@ dependencies = [
{ name = "requests" },
{ name = "safetensors" },
]
sdist = { url = "https://files.pythonhosted.org/packages/03/68/288ca23c7c05c73e87ffe5efffc282400ac9b017f7a9bb03883f4310ea15/diffusers-0.35.2.tar.gz", hash = "sha256:30ecd552303edfcfe1724573c3918a8462ee3ab4d529bdbd4c0045f763affded", size = 3366711, upload-time = "2025-10-15T04:05:17.213Z" }
sdist = { url = "https://files.pythonhosted.org/packages/88/45/ccb2e2180ddf475a0f931dac6a50346310e4c464ce3cccb8a65d1fc1e16d/diffusers-0.36.0.tar.gz", hash = "sha256:a9cde8721b415bde6a678f2d02abb85396487e1b0e0d2b4abb462d14a9825ab0", size = 3795088, upload-time = "2025-12-08T10:14:34.255Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/2e/38d9824f8c6bb048c5ba21c6d4da54c29c162a46b58b3ef907a360a76d3e/diffusers-0.35.2-py3-none-any.whl", hash = "sha256:d50d5e74fdd6dcf55e5c1d304bc52cc7c2659abd1752740d736d7b54078b4db5", size = 4121649, upload-time = "2025-10-15T04:05:14.391Z" },
{ url = "https://files.pythonhosted.org/packages/35/50/281f92cb1f83854dbd79b6e958b3bc5018607e2542971d41604ba7a14b2f/diffusers-0.36.0-py3-none-any.whl", hash = "sha256:525d42abc74bfc3b2db594999961295c054b48ef40a11724dacf50e6abd1af98", size = 4597884, upload-time = "2025-12-08T10:14:31.979Z" },
]
[[package]]
@@ -1632,6 +1637,18 @@ http = [
{ name = "aiohttp" },
]
[[package]]
name = "ftfy"
version = "6.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "wcwidth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
]
[[package]]
name = "future"
version = "1.0.0"
@@ -1760,7 +1777,7 @@ wheels = [
[[package]]
name = "gym-aloha"
version = "0.1.4"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dm-control" },
@@ -1768,14 +1785,14 @@ dependencies = [
{ name = "imageio", extra = ["ffmpeg"] },
{ name = "mujoco" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
]
[[package]]
name = "gym-hil"
version = "0.1.14"
version = "0.1.13"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "gymnasium" },
@@ -1785,9 +1802,9 @@ dependencies = [
{ name = "pygame" },
{ name = "pynput" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
]
[[package]]
@@ -1877,7 +1894,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
[[package]]
name = "hf-libero"
version = "0.1.4"
version = "0.1.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "bddl", marker = "sys_platform == 'linux'" },
@@ -1898,10 +1915,7 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'linux'" },
{ name = "wandb", marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
[[package]]
name = "hf-xet"
@@ -2695,6 +2709,7 @@ all = [
{ name = "faker" },
{ name = "fastapi" },
{ name = "feetech-servo-sdk" },
{ name = "ftfy" },
{ name = "grpcio" },
{ name = "grpcio-tools" },
{ name = "gym-aloha" },
@@ -2703,6 +2718,7 @@ all = [
{ name = "hebi-py" },
{ name = "hf-libero", marker = "sys_platform == 'linux'" },
{ name = "hidapi" },
{ name = "imageio", extra = ["ffmpeg"] },
{ name = "ipykernel" },
{ name = "jsonlines" },
{ name = "jupyter" },
@@ -2876,6 +2892,9 @@ hopejr = [
{ name = "pygame" },
{ name = "pyserial" },
]
imageio-dep = [
{ name = "imageio", extra = ["ffmpeg"] },
]
intelrealsense = [
{ name = "pyrealsense2", marker = "sys_platform != 'darwin'" },
{ name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin'" },
@@ -2900,6 +2919,13 @@ libero = [
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
{ name = "transformers" },
]
lingbot-va = [
{ name = "accelerate" },
{ name = "diffusers" },
{ name = "ftfy" },
{ name = "imageio", extra = ["ffmpeg"] },
{ name = "transformers" },
]
matplotlib-dep = [
{ name = "contourpy" },
{ name = "matplotlib" },
@@ -3069,16 +3095,18 @@ xvla = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'lingbot-va'", specifier = ">=1.10.0,<2.0.0" },
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
{ name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3" },
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
{ name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.36.0" },
{ name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.37.0" },
{ name = "diffusers", marker = "extra == 'lingbot-va'", specifier = ">=0.36.0,<0.37.0" },
{ name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" },
{ name = "draccus", specifier = "==0.10.0" },
{ name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" },
@@ -3087,16 +3115,18 @@ requires-dist = [
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "ftfy", marker = "extra == 'lingbot-va'", specifier = ">=6.0.0,<7.0.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
{ name = "imageio", extras = ["ffmpeg"], marker = "extra == 'imageio-dep'", specifier = ">=2.34.0,<3.0.0" },
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
@@ -3127,6 +3157,7 @@ requires-dist = [
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["eo1"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
@@ -3138,10 +3169,12 @@ requires-dist = [
{ name = "lerobot", extras = ["hardware"], marker = "extra == 'core-scripts'" },
{ name = "lerobot", extras = ["hilserl"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["hopejr"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["imageio-dep"], marker = "extra == 'lingbot-va'" },
{ name = "lerobot", extras = ["intelrealsense"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["kinematics"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["lekiwi"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["libero"], marker = "sys_platform == 'linux' and extra == 'all'" },
{ name = "lerobot", extras = ["lingbot-va"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'async'" },
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
@@ -3198,6 +3231,7 @@ requires-dist = [
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'lingbot-va'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'molmoact2'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" },
@@ -3275,7 +3309,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "imageio-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "lingbot-va", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"