mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ecf342d481 | |||
| 1e762d5240 | |||
| 35c3302f4d | |||
| a323ea67b6 | |||
| 7c063c3fbc | |||
| 9cf12c941d | |||
| 4039da81c6 | |||
| b3a28a49f6 |
@@ -67,6 +67,8 @@
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: fastwam
|
||||
title: FastWAM
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
# FastWAM
|
||||
|
||||
FastWAM is a World Action Model policy for robot control. The LeRobot integration exposes FastWAM through the standard policy API so it can be configured with `policy.type=fastwam`, trained with `lerobot-train`, and loaded through the LeRobot pretrained policy interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
FastWAM keeps video modeling during training, but uses direct action prediction at inference time instead of iteratively generating future observations. This LeRobot policy wraps the FastWAM action model, adapts LeRobot batches to FastWAM training samples, and provides the standard processor pipeline for normalization and action postprocessing.
|
||||
|
||||
The implementation initializes the visual world-model components from `Wan-AI/Wan2.2-TI2V-5B` by default and predicts action chunks with shape `[batch, action_horizon, action_dim]`.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=fastwam` configuration through LeRobot
|
||||
- Image, state, action, and language-task batch adaptation
|
||||
- Action chunk inference through `select_action` and `predict_action_chunk`
|
||||
- Checkpoint save/load through the LeRobot policy APIs
|
||||
- Configurable LIBERO gripper action postprocessing
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Install LeRobot from source, then install FastWAM dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[fastwam]"
|
||||
```
|
||||
|
||||
This installs the FastWAM policy extra from `pyproject.toml`: `transformers`,
|
||||
`diffusers`, `ftfy`, and `regex`, plus LeRobot's base dependencies.
|
||||
|
||||
For LIBERO evaluation, install the benchmark dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[fastwam,libero]"
|
||||
```
|
||||
|
||||
This installs both extras. In addition to the FastWAM dependencies above, the
|
||||
`libero` extra installs LeRobot dataset dependencies, `hf-libero` on Linux, and
|
||||
`scipy`.
|
||||
|
||||
FastWAM uses the Wan2.2 TI2V backbone. The default model id is:
|
||||
|
||||
```python
|
||||
policy.model_id=Wan-AI/Wan2.2-TI2V-5B
|
||||
```
|
||||
|
||||
## Data Requirements
|
||||
|
||||
FastWAM expects a LeRobot dataset with:
|
||||
|
||||
- one or more visual observations whose widths concatenate to `policy.image_size[1]`
|
||||
- `observation.state` when `policy.proprio_dim` is not `None`
|
||||
- `action`
|
||||
- a language task instruction through the dataset task field, or precomputed `context` and `context_mask` tensors
|
||||
|
||||
The default visual setup is one image feature named `observation.images.image` with shape `(3, 224, 448)`. If the dataset uses two cameras, configure `policy.input_features` so their heights match `224` and their widths sum to `448`.
|
||||
|
||||
## Usage
|
||||
|
||||
Create a new FastWAM policy with:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-org/your-dataset \
|
||||
--policy.type=fastwam \
|
||||
--policy.action_dim=7 \
|
||||
--policy.proprio_dim=8 \
|
||||
--policy.action_horizon=32 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.image_size='[224,448]' \
|
||||
--output_dir=./outputs/fastwam_training \
|
||||
--job_name=fastwam_training \
|
||||
--steps=300000 \
|
||||
--batch_size=8 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
Evaluate an existing LeRobot-format checkpoint on LIBERO-10 with:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
|
||||
--policy.device=cuda \
|
||||
--policy.torch_dtype=float32 \
|
||||
--policy.n_action_steps=10 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.observation_height=224 \
|
||||
--env.observation_width=224 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=0 \
|
||||
--env.episode_length=600
|
||||
```
|
||||
|
||||
For `libero_goal`, `libero_spatial`, and `libero_object`, use
|
||||
`--env.episode_length=300`.
|
||||
|
||||
For real-robot rollout, use the same checkpoint path:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--policy.path=your-org/fastwam-real-robot
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Features
|
||||
|
||||
`policy.image_size` is the size of the concatenated FastWAM image tensor as `(height, width)`. Each configured image feature must have shape `(3, height, camera_width)`, and all camera widths must sum to the configured width.
|
||||
|
||||
### Action Chunking
|
||||
|
||||
`policy.action_horizon` controls the number of future actions supervised during training and predicted during inference. `policy.n_action_steps` controls how many actions are consumed before the policy predicts a fresh chunk. `policy.n_action_steps` must be less than or equal to `policy.action_horizon`.
|
||||
|
||||
### Wan Components
|
||||
|
||||
FastWAM loads the Wan VAE, video DiT, text encoder, and tokenizer from the configured Wan model directory or Hugging Face Hub model id. LeRobot-format FastWAM checkpoints saved by `save_pretrained` also copy the local Wan component files needed by `from_pretrained`.
|
||||
|
||||
### LIBERO Action Toggle
|
||||
|
||||
FastWAM LIBERO checkpoints use `policy.toggle_action_dimensions=[-1]` by
|
||||
default to match the gripper action convention used by the original FastWAM
|
||||
evaluation pipeline:
|
||||
|
||||
```bash
|
||||
--policy.toggle_action_dimensions='[-1]'
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| -------------- | -----------: | ---------: |
|
||||
| libero_spatial | 97.6% | 500 |
|
||||
| libero_object | 99.0% | 500 |
|
||||
| libero_goal | 95.0% | 500 |
|
||||
| libero_10 | 94.0% | 500 |
|
||||
| **average** | **96.4%** | 2000 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300` (1x H20 140 GB).
|
||||
|
||||
## References
|
||||
|
||||
- [Fast-WAM paper](https://arxiv.org/abs/2603.16666)
|
||||
- [Fast-WAM project page](https://yuantianyuan01.github.io/FastWAM/)
|
||||
- [Fast-WAM code](https://github.com/yuantianyuan01/FastWAM)
|
||||
- [Released upstream checkpoints](https://huggingface.co/yuanty/fastwam)
|
||||
- [Wan2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{yuan2026fastwam,
|
||||
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
|
||||
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
|
||||
journal = {arXiv preprint arXiv:2603.16666},
|
||||
year = {2026},
|
||||
url = {https://arxiv.org/abs/2603.16666}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,56 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://arxiv.org/abs/2603.16666
|
||||
|
||||
## Repository
|
||||
|
||||
Code: https://github.com/yuantianyuan01/FastWAM
|
||||
|
||||
Project page: https://yuantianyuan01.github.io/FastWAM/
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{yuan2026fastwam,
|
||||
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
|
||||
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
|
||||
journal = {arXiv preprint arXiv:2603.16666},
|
||||
year = {2026},
|
||||
url = {https://arxiv.org/abs/2603.16666}
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Base video model: https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B
|
||||
|
||||
Released upstream checkpoints: https://huggingface.co/yuanty/fastwam
|
||||
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| -------------- | -----------: | ---------: |
|
||||
| libero_spatial | 97.6% | 500 |
|
||||
| libero_object | 99.0% | 500 |
|
||||
| libero_goal | 95.0% | 500 |
|
||||
| libero_10 | 94.0% | 500 |
|
||||
| **average** | **96.4%** | 2000 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300`.
|
||||
|
||||
For LIBERO-10, use `--env.task=libero_10 --env.episode_length=600`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
|
||||
--policy.device=cuda \
|
||||
--policy.torch_dtype=float32 \
|
||||
--policy.n_action_steps=10 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 --env.observation_height=256 --env.observation_width=256 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=0 --env.episode_length=600
|
||||
```
|
||||
@@ -1,547 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Single-image dataloading benchmark across the LeRobot loaders, MADE TO RUN ON A COMPUTE CLUSTER (SLURM).
|
||||
|
||||
This one file is both the orchestrator and the worker:
|
||||
|
||||
* Run it with no ``--scenario`` (from a login node) and it submits a SERIAL sbatch chain of all
|
||||
scenarios below (no two network-bound jobs overlap, so CDN numbers stay clean).
|
||||
* Run it with ``--scenario <name>`` and it executes that single benchmark (this is what each sbatch
|
||||
job calls). The 2-node scenario is launched with ``srun`` and reads ``RANK``/``WORLD_SIZE`` so the
|
||||
streaming dataset splits shards per node.
|
||||
|
||||
Scenarios (all single-frame / non-SARM):
|
||||
1. ``mmap_local`` map-style LeRobotDataset over a LOCAL copy (``--local_root``, no network).
|
||||
2. ``mmap_local_maxworkers`` same, but workers scaled to saturate the node's cores (decode-bound).
|
||||
3. ``stream_hub`` StreamingLeRobotDataset from the Hub (allenai/MolmoAct2-BimanualYAM-Dataset).
|
||||
4. ``stream_bucket`` StreamingLeRobotDataset from a warmed storage bucket (1 node).
|
||||
5. ``stream_bucket_2node`` same warmed bucket, 2 nodes (split_dataset_by_node, per-rank results).
|
||||
|
||||
Reported per run: peak process-tree RSS (max memory), parallel throughput (samples/s, where a sample
|
||||
is one timestep, plus decoded_frames/s = samples/s x num_cameras),
|
||||
single-process throughput, shuffle randomness fraction (distinct episodes per batch / batch size),
|
||||
fetch vs decode split (% of single-process per-sample time), first-batch latency, and p50/p95/p99
|
||||
sample latency. Results are written as JSON + CSV under ``--out_dir``.
|
||||
|
||||
Submit the whole chain (from a login node, inside the repo). Point the scheduler env vars at your own
|
||||
cluster's account/partition/qos, and ``--local_root`` at a local copy of the map-style dataset:
|
||||
ACCOUNT=<account> PARTITION=<partition> QOS=<qos> \\
|
||||
python examples/scaling/benchmark_dataloading.py --local_root /path/to/local/dataset
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import statistics
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.datasets.partition import group_episodes_by_files, partition_episodes
|
||||
|
||||
ROBOCASA_REPO = "pepijn223/robocasa_pretrain_human300_v4"
|
||||
MOLMO_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
MOLMO_BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
# MolmoAct2 is published without a codebase-version git tag, so the version-safe loader would refuse
|
||||
# it; "main" pins the branch directly and skips that check.
|
||||
MOLMO_REVISION = "main"
|
||||
|
||||
# Per-scenario sbatch shape. mem is generous for the streaming legs (32k-episode, 3-camera, 2.35 TB
|
||||
# dataset keeps many AV1 decoders open); the local map-style leg is light. Optional ``num_workers`` /
|
||||
# ``cpus`` override the CLI defaults for that leg.
|
||||
# ``mmap_local_maxworkers``: map-style decode is CPU-bound and each worker decodes its cameras on
|
||||
# parallel threads, so the saturation point is ~num_cpus / num_cameras workers (~90 concurrent decode
|
||||
# threads). The 96-core H100 nodes here schedule at most 92 cpus/task, so we take 92 cpus / 30 workers.
|
||||
SCENARIOS = {
|
||||
"mmap_local": {"kind": "map", "nodes": 1, "mem": "64G", "time": "01:00:00"},
|
||||
"mmap_local_maxworkers": {
|
||||
"kind": "map",
|
||||
"nodes": 1,
|
||||
"mem": "128G",
|
||||
"time": "01:00:00",
|
||||
"num_workers": 30,
|
||||
"cpus": 92,
|
||||
},
|
||||
"stream_hub": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
|
||||
"stream_bucket": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
|
||||
"stream_bucket_2node": {"kind": "stream", "nodes": 2, "mem": "250G", "time": "03:00:00"},
|
||||
}
|
||||
|
||||
|
||||
def _tree_rss_bytes() -> int:
|
||||
"""Sum RSS of this process and all descendants via /proc (DataLoader workers are separate procs)."""
|
||||
try:
|
||||
children: dict[int, list[int]] = {}
|
||||
for entry in os.listdir("/proc"):
|
||||
if not entry.isdigit():
|
||||
continue
|
||||
try:
|
||||
with open(f"/proc/{entry}/stat") as f:
|
||||
ppid = int(f.read().split(") ", 1)[1].split()[1])
|
||||
children.setdefault(ppid, []).append(int(entry))
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
total, stack = 0, [os.getpid()]
|
||||
while stack:
|
||||
cur = stack.pop()
|
||||
try:
|
||||
with open(f"/proc/{cur}/statm") as f:
|
||||
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
stack.extend(children.get(cur, []))
|
||||
return total
|
||||
except OSError:
|
||||
return 0
|
||||
|
||||
|
||||
class PeakRSSSampler:
|
||||
"""Background thread tracking peak process-tree RSS for the duration of the ``with`` block."""
|
||||
|
||||
def __init__(self, interval_s: float = 0.5):
|
||||
self.interval_s = interval_s
|
||||
self.peak_bytes = 0
|
||||
self._stop = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
|
||||
self._stop.wait(self.interval_s)
|
||||
|
||||
def __enter__(self) -> "PeakRSSSampler":
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self._stop.set()
|
||||
self._thread.join(timeout=2)
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
class _TimedStreaming(StreamingLeRobotDataset):
|
||||
"""StreamingLeRobotDataset that times the fetch stage (parquet/network row) separately from the
|
||||
decode stage (video decode + torch conversion in ``_finalize_sample``), so a single-process pass
|
||||
can attribute per-sample cost to fetch vs decode. Timing lives here in the benchmark, not in the
|
||||
library, to keep the dataset itself instrumentation-free."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fetch_s = 0.0
|
||||
self.decode_s = 0.0
|
||||
|
||||
def __iter__(self):
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self._epoch += 1
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
row = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
t1 = time.perf_counter()
|
||||
sample = self._finalize_sample(row)
|
||||
t2 = time.perf_counter()
|
||||
self.fetch_s += t1 - t0
|
||||
self.decode_s += t2 - t1
|
||||
yield sample
|
||||
|
||||
|
||||
def select_node_episodes(
|
||||
meta: LeRobotDatasetMetadata, num_partitions: int, index: int, cap: int
|
||||
) -> list[int]:
|
||||
"""This node's episode share, mirroring lerobot_train ``--data_partition=node``: group episodes by
|
||||
shared video files, LPT-balance the groups by frame count, take this node's bin (capped)."""
|
||||
episodes = list(range(meta.total_episodes))
|
||||
from_idx = meta.episodes["dataset_from_index"]
|
||||
to_idx = meta.episodes["dataset_to_index"]
|
||||
lengths = [int(to_idx[ep] - from_idx[ep]) for ep in episodes]
|
||||
if meta.video_keys:
|
||||
file_columns = {
|
||||
key: (meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])
|
||||
for key in meta.video_keys
|
||||
}
|
||||
else:
|
||||
file_columns = {"data": (meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])}
|
||||
episode_file_ids = [
|
||||
[(key, chunks[ep], files[ep]) for key, (chunks, files) in file_columns.items()] for ep in episodes
|
||||
]
|
||||
groups = group_episodes_by_files(episode_file_ids)
|
||||
if len(groups) < num_partitions:
|
||||
groups = [[i] for i in range(len(episodes))]
|
||||
group_lengths = [sum(lengths[i] for i in g) for g in groups]
|
||||
bins = partition_episodes(group_lengths, num_partitions)
|
||||
chosen = sorted(episodes[i] for g in bins[index] for i in groups[g])
|
||||
return chosen[:cap] if cap and len(chosen) > cap else chosen
|
||||
|
||||
|
||||
def build_dataset(scenario: str, args: argparse.Namespace):
|
||||
"""Return (dataset, meta, is_map_style, info) for the scenario; single-frame (no delta windows)."""
|
||||
if scenario.startswith("mmap_local"):
|
||||
if not args.local_root:
|
||||
raise SystemExit("mmap_local needs --local_root pointing at a local LeRobotDataset copy.")
|
||||
meta = LeRobotDatasetMetadata(ROBOCASA_REPO, root=args.local_root)
|
||||
episodes = select_node_episodes(meta, args.num_partitions, args.partition_index, args.max_episodes)
|
||||
dataset = LeRobotDataset(ROBOCASA_REPO, root=args.local_root, episodes=episodes, tolerance_s=1e-3)
|
||||
return dataset, meta, True, {"loaded_episodes": len(episodes)}
|
||||
|
||||
data_files_root = MOLMO_BUCKET if scenario.startswith("stream_bucket") else None
|
||||
meta = LeRobotDatasetMetadata(MOLMO_REPO, revision=MOLMO_REVISION)
|
||||
dataset = _TimedStreaming(
|
||||
MOLMO_REPO,
|
||||
revision=MOLMO_REVISION,
|
||||
data_files_root=data_files_root,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
max_buffer_input_shards=args.max_buffer_input_shards,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
# Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
|
||||
# dataset may be collapsed); reshard() still yields per-episode shards where it holds.
|
||||
validate_row_groups=False,
|
||||
)
|
||||
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
|
||||
|
||||
|
||||
def _split(fetch_s: float, decode_s: float, getitem_s: float, n_probe: int) -> dict:
|
||||
stage = fetch_s + decode_s
|
||||
return {
|
||||
"single_proc_samples_per_s": round(n_probe / getitem_s, 2) if getitem_s else None,
|
||||
"fetch_pct": round(100 * fetch_s / stage, 1) if stage else None,
|
||||
"decode_pct": round(100 * decode_s / stage, 1) if stage else None,
|
||||
}
|
||||
|
||||
|
||||
def measure_fetch_decode_stream(dataset: _TimedStreaming, n_probe: int, warmup: int) -> dict:
|
||||
"""Single-process pass attributing per-sample time to fetch (parquet/network row) vs decode (video)."""
|
||||
it = iter(dataset)
|
||||
for _ in range(warmup): # exclude the cold shuffle-buffer fill from the ratio
|
||||
next(it)
|
||||
dataset.fetch_s = dataset.decode_s = 0.0
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(n_probe):
|
||||
next(it)
|
||||
return _split(dataset.fetch_s, dataset.decode_s, time.perf_counter() - t0, n_probe)
|
||||
|
||||
|
||||
def measure_fetch_decode_map(dataset: LeRobotDataset, n_probe: int, warmup: int) -> dict:
|
||||
"""Same split for the map-style loader: fetch = raw tabular row (``get_raw_item``), decode = the rest
|
||||
of ``__getitem__`` (video decode + transforms). Local reads make fetch tiny and decode dominant.
|
||||
|
||||
Random frames are resampled past any that torchcodec fails to decode, so a single flaky frame can't
|
||||
abort the whole benchmark (the parallel DataLoader pass draws its own fresh random frames)."""
|
||||
rng = random.Random(0)
|
||||
n = len(dataset)
|
||||
fetch_s = getitem_s = 0.0
|
||||
warmed = measured = skipped = attempts = 0
|
||||
while measured < n_probe and attempts < (warmup + n_probe) * 10:
|
||||
attempts += 1
|
||||
i = rng.randrange(n)
|
||||
try:
|
||||
t0 = time.perf_counter()
|
||||
dataset.get_raw_item(i)
|
||||
t1 = time.perf_counter()
|
||||
dataset[i]
|
||||
t2 = time.perf_counter()
|
||||
except Exception:
|
||||
skipped += 1
|
||||
continue
|
||||
if warmed < warmup:
|
||||
warmed += 1
|
||||
continue
|
||||
fetch_s += t1 - t0
|
||||
getitem_s += t2 - t1
|
||||
measured += 1
|
||||
if skipped:
|
||||
print(f"map fetch/decode probe skipped {skipped} undecodable frame(s)", flush=True)
|
||||
return _split(fetch_s, max(0.0, getitem_s - fetch_s), getitem_s, measured)
|
||||
|
||||
|
||||
def run_scenario(scenario: str, args: argparse.Namespace) -> None:
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
device = torch.device(args.device)
|
||||
|
||||
dataset, meta, is_map_style, info = build_dataset(scenario, args)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
shuffle=is_map_style, # map-style: global random shuffle; streaming: shuffled inside the dataset
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=True,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
persistent_workers=args.num_workers > 0,
|
||||
)
|
||||
|
||||
sample_latencies_ms: list[float] = []
|
||||
episodes_per_batch: list[int] = []
|
||||
samples = 0
|
||||
first_batch_latency_s = None
|
||||
steady_start = None
|
||||
|
||||
t_start = time.perf_counter()
|
||||
t_prev = t_start
|
||||
with PeakRSSSampler() as rss:
|
||||
for i, batch in enumerate(loader):
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
if i == args.warmup_batches:
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
samples += args.batch_size
|
||||
ep = batch.get("episode_index")
|
||||
if torch.is_tensor(ep):
|
||||
episodes_per_batch.append(int(torch.unique(ep).numel()))
|
||||
t_prev = now
|
||||
# Measure throughput over a fixed wall-clock window (after warmup) so every scenario is
|
||||
# compared over the same duration regardless of its speed; num_batches is only a safety cap.
|
||||
if steady_start is not None and (now - steady_start) >= args.duration_s:
|
||||
break
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
|
||||
|
||||
now = time.perf_counter()
|
||||
elapsed = now - t_start
|
||||
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
|
||||
|
||||
if samples == 0:
|
||||
raise SystemExit(
|
||||
f"FAILED: 0 samples in {args.duration_s}s for scenario={scenario} "
|
||||
"(inspect worker logs; try --num_workers 0 to surface the exception)."
|
||||
)
|
||||
|
||||
# Single-process fetch/decode split + single-proc throughput. Run AFTER the DataLoader pass: this
|
||||
# decodes video in the main process, which must stay decode-clean until the workers have forked
|
||||
# (decoding before fork corrupts the workers' torchcodec state).
|
||||
del loader
|
||||
if is_map_style:
|
||||
fetch_decode = measure_fetch_decode_map(dataset, args.probe_samples, args.probe_warmup)
|
||||
else:
|
||||
fetch_decode = measure_fetch_decode_stream(dataset, args.probe_samples, args.probe_warmup)
|
||||
|
||||
image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
|
||||
num_cameras = len(meta.video_keys)
|
||||
results = {
|
||||
"scenario": scenario,
|
||||
"rank": rank,
|
||||
"world_size": world_size,
|
||||
"loader": "map_style" if is_map_style else "streaming",
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"episode_pool_size": None if is_map_style else args.episode_pool_size,
|
||||
"max_buffer_input_shards": None
|
||||
if is_map_style
|
||||
else (args.max_buffer_input_shards or args.episode_pool_size),
|
||||
**info,
|
||||
"num_cameras": num_cameras,
|
||||
"image_shape": image_shape,
|
||||
"fps": meta.fps,
|
||||
"peak_rss_gb": peak_rss_gb,
|
||||
"samples_measured": samples,
|
||||
"steady_window_s": round(steady_elapsed_s, 2),
|
||||
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 3),
|
||||
# Parallel throughput over the steady window (excludes warmup + the prefetch queue it filled).
|
||||
# A sample is one timestep (one dataset item); it decodes num_cameras video frames.
|
||||
"samples_per_s": round(samples / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"decoded_frames_per_s": round(samples / steady_elapsed_s * num_cameras, 2)
|
||||
if steady_elapsed_s
|
||||
else 0.0,
|
||||
**fetch_decode,
|
||||
# Distinct episodes per batch / batch size: ~1.0 ≈ map-style uniform, low ≈ correlated samples.
|
||||
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
|
||||
if episodes_per_batch
|
||||
else None,
|
||||
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
|
||||
if sample_latencies_ms
|
||||
else None,
|
||||
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
|
||||
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
|
||||
"total_time_s": round(elapsed, 2),
|
||||
}
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"{scenario}_bs{args.batch_size}_w{args.num_workers}_r{rank}of{world_size}"
|
||||
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
|
||||
flat = {k: (json.dumps(v) if isinstance(v, (dict, list)) else v) for k, v in results.items()}
|
||||
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=list(flat))
|
||||
writer.writeheader()
|
||||
writer.writerow(flat)
|
||||
print(json.dumps(results, indent=2), flush=True)
|
||||
print(f"Wrote {out_dir / tag}.json and .csv", flush=True)
|
||||
|
||||
|
||||
def submit_chain(args: argparse.Namespace) -> None:
|
||||
"""Submit every scenario as a serial sbatch chain (one network-bound job at a time).
|
||||
|
||||
Bodies are passed to ``sbatch --wrap`` as a single argv (no outer shell), so ``$SLURM_PROCID`` /
|
||||
``$SLURM_NTASKS`` stay literal and expand at job runtime, not at submit time.
|
||||
"""
|
||||
this_file = Path(__file__).resolve()
|
||||
repo_dir = str(this_file.parents[2]) # <repo>/examples/scaling/<this file>
|
||||
logs = Path(repo_dir) / "logs"
|
||||
logs.mkdir(exist_ok=True)
|
||||
run = f"conda run --no-capture-output -n {args.conda_env} python"
|
||||
common = (
|
||||
f"--batch_size {args.batch_size} "
|
||||
f"--prefetch_factor {args.prefetch_factor} --episode_pool_size {args.episode_pool_size} "
|
||||
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
|
||||
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
|
||||
)
|
||||
if args.max_buffer_input_shards is not None:
|
||||
common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
|
||||
if args.local_root:
|
||||
common += f" --local_root {args.local_root}"
|
||||
env_prefix = "export TOKENIZERS_PARALLELISM=false"
|
||||
sched = []
|
||||
for opt, env in (("--account", "ACCOUNT"), ("--partition", "PARTITION"), ("--qos", "QOS")):
|
||||
if os.environ.get(env):
|
||||
sched.append(f"{opt}={os.environ[env]}")
|
||||
|
||||
selected = args.scenarios.split(",") if args.scenarios else list(SCENARIOS)
|
||||
prev = ""
|
||||
for scenario in selected:
|
||||
cfg = SCENARIOS[scenario]
|
||||
nw = cfg.get("num_workers", args.num_workers)
|
||||
cpus = cfg.get("cpus", nw + 4)
|
||||
worker = f"{run} {this_file} --scenario {scenario} --num_workers {nw} {common}"
|
||||
if cfg["nodes"] > 1:
|
||||
# One task per node; each exports RANK/WORLD_SIZE so the stream splits shards per node.
|
||||
inner = f"export RANK=$SLURM_PROCID WORLD_SIZE=$SLURM_NTASKS && cd {repo_dir} && {env_prefix} && {worker}"
|
||||
body = f"srun --export=ALL bash -c '{inner}'"
|
||||
node_flags = [f"--nodes={cfg['nodes']}", "--ntasks-per-node=1", "--gpus-per-node=1"]
|
||||
else:
|
||||
body = f"cd {repo_dir} && {env_prefix} && {worker}"
|
||||
node_flags = ["--nodes=1", "--ntasks=1", "--gpus=1"]
|
||||
cmd = [
|
||||
"sbatch",
|
||||
"--parsable",
|
||||
f"--job-name=dlbench_{scenario}",
|
||||
*node_flags,
|
||||
f"--cpus-per-task={cpus}",
|
||||
f"--mem={cfg['mem']}",
|
||||
f"--time={cfg['time']}",
|
||||
f"--output={logs}/%x-%j.out",
|
||||
*sched,
|
||||
]
|
||||
if prev:
|
||||
cmd.append(f"--dependency=afterany:{prev}")
|
||||
cmd += ["--wrap", body]
|
||||
jid = subprocess.check_output(cmd, text=True).strip().split(";")[0]
|
||||
print(f"submitted {jid} dlbench_{scenario}{f' (after {prev})' if prev else ''}", flush=True)
|
||||
prev = jid
|
||||
|
||||
print(f"\nSubmitted {len(selected)} jobs as a serial chain. Results: {args.out_dir}/*.json", flush=True)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument(
|
||||
"--scenario",
|
||||
choices=list(SCENARIOS),
|
||||
default=None,
|
||||
help="Run ONE scenario (worker mode). Omit to submit the whole chain (orchestrator mode).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--scenarios",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Orchestrator only: comma-separated subset of scenarios to submit (default: all).",
|
||||
)
|
||||
p.add_argument("--local_root", type=str, default=None, help="Local LeRobotDataset copy for mmap_local.")
|
||||
p.add_argument(
|
||||
"--num_partitions", type=int, default=8, help="Node count for mmap_local episode partition."
|
||||
)
|
||||
p.add_argument("--partition_index", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_episodes", type=int, default=512, help="Cap mmap_local episodes to the local share."
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=64)
|
||||
p.add_argument("--num_workers", type=int, default=8)
|
||||
p.add_argument("--prefetch_factor", type=int, default=2)
|
||||
p.add_argument(
|
||||
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_buffer_input_shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Concurrently-live random episodes feeding the pool after reshard() "
|
||||
"(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--duration_s", type=float, default=60.0, help="Steady-state measurement window (seconds)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_batches", type=int, default=1_000_000, help="Safety cap; duration_s governs the window."
|
||||
)
|
||||
p.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state throughput.")
|
||||
p.add_argument(
|
||||
"--probe_samples", type=int, default=100, help="Single-process samples for fetch/decode split."
|
||||
)
|
||||
p.add_argument(
|
||||
"--probe_warmup", type=int, default=10, help="Samples skipped before the fetch/decode probe."
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
p.add_argument("--conda_env", type=str, default="lerobot", help="Conda env the chained jobs run in.")
|
||||
p.add_argument("--out_dir", type=str, default="benchmarks/streaming/results_dataloading")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.scenario is None:
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
"NOTE: no --scenario given, submitting the SLURM chain. This benchmark is meant to run on a "
|
||||
"compute cluster; run from a login node with ACCOUNT/PARTITION/QOS set.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
submit_chain(args)
|
||||
else:
|
||||
run_scenario(args.scenario, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+9
-14
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -216,7 +216,11 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
fastwam = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -231,9 +235,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -280,6 +284,7 @@ all = [
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
"lerobot[fastwam]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
@@ -333,16 +338,6 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
|
||||
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
|
||||
# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships
|
||||
# them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||
# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the
|
||||
# shared httpx.Client on every ConnectError, which races with concurrent streaming requests
|
||||
# ("Cannot send a request, as the client has been closed"); we patch that out locally in
|
||||
# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch.
|
||||
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Build mmap-able byte-index sidecars for LeRobot streaming datasets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.datasets.byte_index_builder import (
|
||||
build_byte_index_tables,
|
||||
load_existing_file_ids,
|
||||
write_byte_index,
|
||||
)
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.")
|
||||
parser.add_argument("--repo-id", required=True)
|
||||
parser.add_argument("--revision", default=None)
|
||||
parser.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/")
|
||||
parser.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory")
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)")
|
||||
parser.add_argument("--no-keyframes", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
output = args.output
|
||||
existing = load_existing_file_ids(output)
|
||||
if existing:
|
||||
logger.info("resuming: %s files already indexed", len(existing))
|
||||
|
||||
files_tbl, episodes_tbl, keyframes_tbl = build_byte_index_tables(
|
||||
meta,
|
||||
args.data_root,
|
||||
include_keyframes=not args.no_keyframes,
|
||||
workers=args.workers,
|
||||
existing_files=existing,
|
||||
max_episodes=args.max_episodes,
|
||||
)
|
||||
write_byte_index(output, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True)
|
||||
logger.info("wrote byte index to %s", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -39,10 +39,6 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
|
||||
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
|
||||
# the pool holds tabular rows only. Ignored when streaming is False.
|
||||
streaming_episode_pool_size: int = 1024
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
"""Runtime in-memory byte index loaded from precomputed sidecar parquet."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME
|
||||
from .mp4_episode_slice import episode_custom_frame_mappings_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EpisodeSliceLookup:
|
||||
global_episode_id: int
|
||||
file_id: int
|
||||
mdat_offset: int
|
||||
mdat_length: int
|
||||
frame_count: int
|
||||
first_pts: float
|
||||
last_pts: float
|
||||
avg_fps: float
|
||||
|
||||
@property
|
||||
def fetch_bytes(self) -> int:
|
||||
return self.mdat_length
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FileLookup:
|
||||
file_id: int
|
||||
file_path: str
|
||||
file_size: int
|
||||
moov_offset: int
|
||||
moov_length: int
|
||||
header_length: int
|
||||
faststart: bool
|
||||
avg_fps: float
|
||||
codec: str
|
||||
|
||||
|
||||
class EpisodeByteIndex:
|
||||
"""Columnar byte-index resident in numpy arrays for O(1) episode lookup."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_dir: str | Path | None,
|
||||
*,
|
||||
video_keys: list[str],
|
||||
num_episodes: int,
|
||||
mmap: bool = True,
|
||||
files_table: pa.Table | None = None,
|
||||
episodes_table: pa.Table | None = None,
|
||||
mp4_by_rel: dict[str, Any] | None = None,
|
||||
):
|
||||
self.index_dir = Path(index_dir) if index_dir is not None else None
|
||||
self.video_keys = list(video_keys)
|
||||
self.num_episodes = num_episodes
|
||||
self.num_cameras = len(video_keys)
|
||||
self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)}
|
||||
self._mp4_by_rel = mp4_by_rel
|
||||
self._frame_mappings_by_gid: dict[int, bytes] = {}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
if files_table is not None and episodes_table is not None:
|
||||
files_tbl, episodes_tbl = files_table, episodes_table
|
||||
else:
|
||||
if self.index_dir is None:
|
||||
raise ValueError("index_dir or in-memory tables required")
|
||||
files_path = self.index_dir / FILES_NAME
|
||||
episodes_path = self.index_dir / EPISODES_NAME
|
||||
if not files_path.exists() or not episodes_path.exists():
|
||||
raise FileNotFoundError(f"byte index missing under {self.index_dir}")
|
||||
files_tbl = pq.read_table(files_path, memory_map=mmap)
|
||||
episodes_tbl = pq.read_table(episodes_path, memory_map=mmap)
|
||||
|
||||
self._load_tables(files_tbl, episodes_tbl, mmap=mmap)
|
||||
self.build_time_s = time.perf_counter() - t0
|
||||
self.load_time_s = self.build_time_s
|
||||
|
||||
def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None:
|
||||
def col(tbl, name: str):
|
||||
array = tbl.column(name).combine_chunks()
|
||||
if pa.types.is_boolean(array.type):
|
||||
return array.to_numpy(zero_copy_only=False)
|
||||
return array.to_numpy()
|
||||
|
||||
self.file_id = col(files_tbl, "file_id")
|
||||
self.file_path = files_tbl.column("file_path").to_pylist()
|
||||
self.file_size = col(files_tbl, "file_size")
|
||||
self.moov_offset = col(files_tbl, "moov_offset")
|
||||
self.moov_length = col(files_tbl, "moov_length")
|
||||
self.header_length = col(files_tbl, "header_length")
|
||||
self.faststart = col(files_tbl, "faststart")
|
||||
self.file_avg_fps = col(files_tbl, "avg_fps")
|
||||
self.codec = files_tbl.column("codec").to_pylist()
|
||||
|
||||
ep = episodes_tbl
|
||||
n = len(ep)
|
||||
gid = col(ep, "global_episode_id")
|
||||
order = np.argsort(gid)
|
||||
self._global_episode_id = gid[order]
|
||||
self._episode_index = col(ep, "episode_index")[order]
|
||||
self._camera_index = col(ep, "camera_index")[order]
|
||||
self._file_id = col(ep, "file_id")[order]
|
||||
self._mdat_offset = col(ep, "mdat_offset")[order]
|
||||
self._mdat_length = col(ep, "mdat_length")[order]
|
||||
self._frame_count = col(ep, "frame_count")[order]
|
||||
self._first_pts = col(ep, "first_pts")[order]
|
||||
self._last_pts = col(ep, "last_pts")[order]
|
||||
|
||||
expected = self.num_episodes * self.num_cameras
|
||||
if n != expected:
|
||||
raise ValueError(f"byte index episodes rows {n} != expected {expected}")
|
||||
|
||||
if self.index_dir is not None:
|
||||
keyframes_path = self.index_dir / KEYFRAMES_NAME
|
||||
if keyframes_path.exists():
|
||||
kf_tbl = pq.read_table(keyframes_path, memory_map=mmap)
|
||||
self._keyframes_rows = len(kf_tbl)
|
||||
else:
|
||||
self._keyframes_rows = 0
|
||||
else:
|
||||
self._keyframes_rows = 0
|
||||
|
||||
self.resident_bytes = int(
|
||||
self._global_episode_id.nbytes
|
||||
+ self._file_id.nbytes
|
||||
+ self._mdat_offset.nbytes
|
||||
+ self._mdat_length.nbytes
|
||||
+ self.file_size.nbytes
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex:
|
||||
return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes)
|
||||
|
||||
@classmethod
|
||||
def from_memory_build(
|
||||
cls,
|
||||
meta,
|
||||
data_root: str,
|
||||
*,
|
||||
workers: int = 8,
|
||||
max_episodes: int | None = None,
|
||||
include_frame_mappings_cache: bool = True,
|
||||
) -> EpisodeByteIndex:
|
||||
"""Build a complete byte index in RAM (no parquet write, no dataset push)."""
|
||||
from .byte_index_builder import build_byte_index_in_memory
|
||||
|
||||
return build_byte_index_in_memory(
|
||||
meta,
|
||||
data_root,
|
||||
workers=workers,
|
||||
max_episodes=max_episodes,
|
||||
include_frame_mappings_cache=include_frame_mappings_cache,
|
||||
)
|
||||
|
||||
def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup:
|
||||
cam_idx = self._cam_to_idx[camera_key]
|
||||
gid = episode_index * self.num_cameras + cam_idx
|
||||
row = int(gid)
|
||||
if row < 0 or row >= len(self._global_episode_id):
|
||||
raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
|
||||
file_id = int(self._file_id[row])
|
||||
return EpisodeSliceLookup(
|
||||
global_episode_id=gid,
|
||||
file_id=file_id,
|
||||
mdat_offset=int(self._mdat_offset[row]),
|
||||
mdat_length=int(self._mdat_length[row]),
|
||||
frame_count=int(self._frame_count[row]),
|
||||
first_pts=float(self._first_pts[row]),
|
||||
last_pts=float(self._last_pts[row]),
|
||||
avg_fps=float(self.file_avg_fps[file_id]),
|
||||
)
|
||||
|
||||
def file_lookup(self, file_id: int) -> FileLookup:
|
||||
return FileLookup(
|
||||
file_id=file_id,
|
||||
file_path=self.file_path[file_id],
|
||||
file_size=int(self.file_size[file_id]),
|
||||
moov_offset=int(self.moov_offset[file_id]),
|
||||
moov_length=int(self.moov_length[file_id]),
|
||||
header_length=int(self.header_length[file_id]),
|
||||
faststart=bool(self.faststart[file_id]),
|
||||
avg_fps=float(self.file_avg_fps[file_id]),
|
||||
codec=self.codec[file_id],
|
||||
)
|
||||
|
||||
def header_byte_range(self, file_id: int) -> tuple[int, int]:
|
||||
length = int(self.header_length[file_id])
|
||||
return 0, max(0, length - 1)
|
||||
|
||||
def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None:
|
||||
cam_idx = self._cam_to_idx[camera_key]
|
||||
gid = episode_index * self.num_cameras + cam_idx
|
||||
cached = self._frame_mappings_by_gid.get(gid)
|
||||
if cached is not None:
|
||||
return cached
|
||||
if self._mp4_by_rel is None:
|
||||
return None
|
||||
lookup = self.lookup(episode_index, camera_key)
|
||||
rel = self.file_path[lookup.file_id]
|
||||
mp4_index = self._mp4_by_rel.get(rel)
|
||||
if mp4_index is None:
|
||||
return None
|
||||
payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts)
|
||||
self._frame_mappings_by_gid[gid] = payload
|
||||
return payload
|
||||
|
||||
def stats_dict(self) -> dict[str, float | int]:
|
||||
return {
|
||||
"load_time_s": self.load_time_s,
|
||||
"build_time_s": self.build_time_s,
|
||||
"resident_bytes": self.resident_bytes,
|
||||
"frame_mappings_cached": len(self._frame_mappings_by_gid),
|
||||
"mp4_indices_cached": len(self._mp4_by_rel or {}),
|
||||
"num_files": len(self.file_path),
|
||||
"num_episode_rows": len(self._global_episode_id),
|
||||
}
|
||||
@@ -1,281 +0,0 @@
|
||||
"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from .mp4_episode_slice import (
|
||||
HEADER_PROBE_BYTES,
|
||||
MAX_HEADER_PROBE_BYTES,
|
||||
average_fps_from_index,
|
||||
episode_keyframes,
|
||||
parse_mp4_file_layout,
|
||||
parse_mp4_index,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BYTE_INDEX_DIR = "meta/byte_index"
|
||||
FILES_NAME = "files.parquet"
|
||||
EPISODES_NAME = "episodes.parquet"
|
||||
KEYFRAMES_NAME = "keyframes.parquet"
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexedFile:
|
||||
file_id: int
|
||||
file_path: str
|
||||
file_size: int
|
||||
moov_offset: int
|
||||
moov_length: int
|
||||
header_length: int
|
||||
faststart: bool
|
||||
avg_fps: float
|
||||
codec: str
|
||||
|
||||
|
||||
def fetch_header_bytes(path: str, file_size: int) -> bytes:
|
||||
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
|
||||
probe = HEADER_PROBE_BYTES
|
||||
while True:
|
||||
with fs.open(path, "rb", block_size=max(probe, 2**20), cache_type="none") as f:
|
||||
header = f.read(min(probe, file_size))
|
||||
try:
|
||||
parse_mp4_file_layout(header, file_size)
|
||||
return header
|
||||
except ValueError as exc:
|
||||
if probe >= min(MAX_HEADER_PROBE_BYTES, file_size) or "mdat box not found" not in str(exc):
|
||||
raise
|
||||
probe = min(probe * 2, MAX_HEADER_PROBE_BYTES, file_size)
|
||||
|
||||
|
||||
def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]:
|
||||
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
|
||||
file_size = fs.info(path)["size"]
|
||||
header = fetch_header_bytes(path, file_size)
|
||||
layout = parse_mp4_file_layout(header, file_size)
|
||||
if not layout.faststart:
|
||||
logger.warning("non-faststart MP4 (moov after mdat): %s", path)
|
||||
mp4_index = parse_mp4_index(header, file_size)
|
||||
indexed = IndexedFile(
|
||||
file_id=-1,
|
||||
file_path=rel_path or path,
|
||||
file_size=file_size,
|
||||
moov_offset=layout.moov_offset,
|
||||
moov_length=layout.moov_length,
|
||||
header_length=layout.header_end,
|
||||
faststart=layout.faststart,
|
||||
avg_fps=average_fps_from_index(mp4_index),
|
||||
codec=layout.codec,
|
||||
)
|
||||
return indexed, mp4_index
|
||||
|
||||
|
||||
def build_byte_index_tables(
|
||||
meta,
|
||||
data_root: str,
|
||||
*,
|
||||
file_paths: list[str] | None = None,
|
||||
include_keyframes: bool = True,
|
||||
workers: int = 8,
|
||||
existing_files: dict[str, int] | None = None,
|
||||
max_episodes: int | None = None,
|
||||
return_mp4_indices: bool = False,
|
||||
complete_files_table: bool = False,
|
||||
) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]:
|
||||
"""Build files/episodes/(optional keyframes) Arrow tables."""
|
||||
video_keys = list(meta.video_keys)
|
||||
n_cams = len(video_keys)
|
||||
cam_to_idx = {cam: i for i, cam in enumerate(video_keys)}
|
||||
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
|
||||
|
||||
rel_paths: set[str] = set()
|
||||
for ep_idx in range(num_episodes):
|
||||
for cam in video_keys:
|
||||
rel_paths.add(str(meta.get_video_file_path(ep_idx, cam)))
|
||||
path_by_rel = {rel: f"{data_root.rstrip('/')}/{rel}" for rel in sorted(rel_paths)}
|
||||
if file_paths is None:
|
||||
file_paths = list(path_by_rel.values())
|
||||
rel_by_path = {path_by_rel[rel]: rel for rel in path_by_rel}
|
||||
|
||||
existing_files = existing_files or {}
|
||||
file_meta_by_rel: dict[str, dict[str, Any]] = {}
|
||||
mp4_by_rel: dict[str, Any] = {}
|
||||
next_file_id = max(existing_files.values(), default=-1) + 1
|
||||
|
||||
to_index = [rel for rel in sorted(rel_paths) if rel not in existing_files]
|
||||
if to_index:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel for rel in to_index
|
||||
}
|
||||
for fut in as_completed(futures):
|
||||
rel = futures[fut]
|
||||
indexed, mp4_index = fut.result()
|
||||
indexed.file_id = next_file_id
|
||||
mp4_by_rel[rel] = mp4_index
|
||||
file_meta_by_rel[rel] = {
|
||||
"file_id": indexed.file_id,
|
||||
"file_path": rel,
|
||||
"file_size": indexed.file_size,
|
||||
"moov_offset": indexed.moov_offset,
|
||||
"moov_length": indexed.moov_length,
|
||||
"header_length": indexed.header_length,
|
||||
"faststart": indexed.faststart,
|
||||
"avg_fps": indexed.avg_fps,
|
||||
"codec": indexed.codec,
|
||||
}
|
||||
existing_files[rel] = indexed.file_id
|
||||
next_file_id += 1
|
||||
|
||||
missing_rels = {
|
||||
str(meta.get_video_file_path(ep, cam))
|
||||
for ep in range(num_episodes)
|
||||
for cam in video_keys
|
||||
} - set(mp4_by_rel.keys())
|
||||
if missing_rels:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel
|
||||
for rel in missing_rels
|
||||
if rel not in mp4_by_rel
|
||||
}
|
||||
for fut in as_completed(futures):
|
||||
rel = futures[fut]
|
||||
_, mp4_index = fut.result()
|
||||
mp4_by_rel[rel] = mp4_index
|
||||
|
||||
episode_rows: list[dict[str, Any]] = []
|
||||
keyframe_rows: list[dict[str, Any]] = []
|
||||
for ep_idx in range(num_episodes):
|
||||
for cam in video_keys:
|
||||
rel = str(meta.get_video_file_path(ep_idx, cam))
|
||||
path = f"{data_root.rstrip('/')}/{rel}"
|
||||
if rel not in existing_files:
|
||||
raise KeyError(f"file not indexed: {rel}")
|
||||
mp4_index = mp4_by_rel[rel]
|
||||
ep = meta.episodes[ep_idx]
|
||||
from_ts = float(ep[f"videos/{cam}/from_timestamp"])
|
||||
to_ts = float(ep[f"videos/{cam}/to_timestamp"])
|
||||
span = mp4_index.episode_byte_span(from_ts, to_ts)
|
||||
global_episode_id = ep_idx * n_cams + cam_to_idx[cam]
|
||||
mdat_length = span.slice_hi - span.slice_lo + 1
|
||||
episode_rows.append(
|
||||
{
|
||||
"global_episode_id": global_episode_id,
|
||||
"episode_index": ep_idx,
|
||||
"camera_key": cam,
|
||||
"camera_index": cam_to_idx[cam],
|
||||
"file_id": existing_files[rel],
|
||||
"mdat_offset": span.slice_lo,
|
||||
"mdat_length": mdat_length,
|
||||
"frame_count": max(1, round((to_ts - from_ts) * meta.fps)),
|
||||
"first_pts": from_ts,
|
||||
"last_pts": to_ts,
|
||||
}
|
||||
)
|
||||
if include_keyframes:
|
||||
timescale = mp4_index.timescale
|
||||
for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts):
|
||||
keyframe_rows.append(
|
||||
{
|
||||
"global_episode_id": global_episode_id,
|
||||
"pts": int(round(pts_s * timescale)),
|
||||
"byte_offset": byte_off,
|
||||
}
|
||||
)
|
||||
|
||||
referenced_rels = {
|
||||
str(meta.get_video_file_path(ep, cam)) for ep in range(num_episodes) for cam in video_keys
|
||||
}
|
||||
if complete_files_table:
|
||||
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(referenced_rels)])
|
||||
elif to_index:
|
||||
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(to_index)])
|
||||
else:
|
||||
files_table = None
|
||||
episodes_table = pa.Table.from_pylist(episode_rows)
|
||||
keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None
|
||||
if return_mp4_indices:
|
||||
return files_table, episodes_table, keyframes_table, mp4_by_rel
|
||||
return files_table, episodes_table, keyframes_table
|
||||
|
||||
|
||||
def build_byte_index_in_memory(
|
||||
meta,
|
||||
data_root: str,
|
||||
*,
|
||||
workers: int = 8,
|
||||
max_episodes: int | None = None,
|
||||
include_frame_mappings_cache: bool = False,
|
||||
):
|
||||
"""Build a complete byte index resident in RAM (no parquet write, no dataset push)."""
|
||||
from .byte_index import EpisodeByteIndex
|
||||
|
||||
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
|
||||
files_tbl, episodes_tbl, _, mp4_by_rel = build_byte_index_tables(
|
||||
meta,
|
||||
data_root,
|
||||
include_keyframes=False,
|
||||
workers=workers,
|
||||
max_episodes=max_episodes,
|
||||
return_mp4_indices=True,
|
||||
complete_files_table=True,
|
||||
)
|
||||
index = EpisodeByteIndex(
|
||||
None,
|
||||
video_keys=list(meta.video_keys),
|
||||
num_episodes=num_episodes,
|
||||
files_table=files_tbl,
|
||||
episodes_table=episodes_tbl,
|
||||
mp4_by_rel=mp4_by_rel,
|
||||
)
|
||||
if include_frame_mappings_cache:
|
||||
for ep_idx in range(num_episodes):
|
||||
for cam in meta.video_keys:
|
||||
index.custom_frame_mappings(ep_idx, cam)
|
||||
return index
|
||||
|
||||
|
||||
def write_byte_index(
|
||||
output_dir: Path,
|
||||
files_table: pa.Table | None,
|
||||
episodes_table: pa.Table,
|
||||
keyframes_table: pa.Table | None = None,
|
||||
*,
|
||||
merge_existing: bool = True,
|
||||
) -> None:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
files_path = output_dir / FILES_NAME
|
||||
episodes_path = output_dir / EPISODES_NAME
|
||||
keyframes_path = output_dir / KEYFRAMES_NAME
|
||||
|
||||
if merge_existing and files_path.exists() and files_table is not None:
|
||||
prev = pq.read_table(files_path)
|
||||
files_table = pa.concat_tables([prev, files_table])
|
||||
|
||||
if files_table is not None:
|
||||
pq.write_table(files_table, files_path)
|
||||
|
||||
pq.write_table(episodes_table, episodes_path)
|
||||
if keyframes_table is not None:
|
||||
if merge_existing and keyframes_path.exists():
|
||||
keyframes_table = pa.concat_tables([pq.read_table(keyframes_path), keyframes_table])
|
||||
pq.write_table(keyframes_table, keyframes_path)
|
||||
|
||||
|
||||
def load_existing_file_ids(index_dir: Path) -> dict[str, int]:
|
||||
files_path = index_dir / FILES_NAME
|
||||
if not files_path.exists():
|
||||
return {}
|
||||
table = pq.read_table(files_path, columns=["file_id", "file_path"])
|
||||
return {row["file_path"]: int(row["file_id"]) for row in table.to_pylist()}
|
||||
@@ -945,17 +945,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
# Emit several row groups with a page index instead of one giant row group. A single row group forces
|
||||
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
|
||||
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
|
||||
# groups bounds per-shard memory while keeping groups large enough to scan
|
||||
# efficiently; the page index lets readers skip to the pages they need.
|
||||
target_row_group_bytes = 32 * 1024 * 1024
|
||||
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
|
||||
writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
|
||||
)
|
||||
writer.write_table(table, row_group_size=row_group_size)
|
||||
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
writer.write_table(table)
|
||||
writer.close()
|
||||
|
||||
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
"""Node-local LRU byte cache using precomputed byte-index manifest sidecars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
|
||||
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
|
||||
from .mp4_episode_slice import SparseMp4Reader
|
||||
from .torchcodec_utils import open_video_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats:
|
||||
hits: int = 0
|
||||
misses: int = 0
|
||||
bytes_fetched: int = 0
|
||||
full_file_fallbacks: int = 0
|
||||
prefetch_submitted: int = 0
|
||||
prefetch_waits: int = 0
|
||||
mdat_slices: int = 0
|
||||
prefix_fetches: int = 0
|
||||
fetch_to_buffer_s: float = 0.0
|
||||
buffer_to_decoder_s: float = 0.0
|
||||
buffer_hit_decoder_s: float = 0.0
|
||||
decode_frame_s: float = 0.0
|
||||
decode_frames: int = 0
|
||||
|
||||
def merge(self, other: CacheStats) -> None:
|
||||
for name in self.__dataclass_fields__:
|
||||
setattr(self, name, getattr(self, name) + getattr(other, name))
|
||||
|
||||
def stats_dict(self) -> dict[str, int | float]:
|
||||
avg_miss = self.bytes_fetched / max(1, self.misses)
|
||||
return {
|
||||
"byte_cache_hits": self.hits,
|
||||
"byte_cache_misses": self.misses,
|
||||
"byte_cache_bytes_fetched": self.bytes_fetched,
|
||||
"byte_cache_bytes_per_miss": avg_miss,
|
||||
"byte_cache_full_file_fallbacks": self.full_file_fallbacks,
|
||||
"byte_cache_prefetch_submitted": self.prefetch_submitted,
|
||||
"byte_cache_prefetch_waits": self.prefetch_waits,
|
||||
"byte_cache_mdat_slices": self.mdat_slices,
|
||||
"byte_cache_prefix_fetches": self.prefix_fetches,
|
||||
"fetch_to_buffer_ms_per_miss": 1000 * self.fetch_to_buffer_s / max(1, self.misses),
|
||||
"buffer_to_decoder_ms_per_miss": 1000 * self.buffer_to_decoder_s / max(1, self.misses),
|
||||
"buffer_hit_decoder_ms_per_hit": 1000 * self.buffer_hit_decoder_s / max(1, self.hits),
|
||||
"decode_ms_per_frame": 1000 * self.decode_frame_s / max(1, self.decode_frames),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _EpisodeEntry:
|
||||
decoders: dict[str, Any] = field(default_factory=dict)
|
||||
ready: threading.Event = field(default_factory=threading.Event)
|
||||
error: Exception | None = None
|
||||
|
||||
|
||||
class RangeFetcher:
|
||||
"""Sequential byte-range GETs via fsspec."""
|
||||
|
||||
def __init__(self, path: str):
|
||||
self.path = path
|
||||
self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
|
||||
|
||||
def fetch(self, lo: int, hi: int) -> bytes:
|
||||
if hi < lo:
|
||||
return b""
|
||||
with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f:
|
||||
f.seek(lo)
|
||||
return f.read(hi - lo + 1)
|
||||
|
||||
|
||||
class EpisodeByteCache:
|
||||
"""Manifest-driven episode MP4 fetch + in-memory sparse decode."""
|
||||
|
||||
MAX_BYTES_PER_MISS = 25 * 1024 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
byte_index: EpisodeByteIndex,
|
||||
max_bytes: int,
|
||||
*,
|
||||
data_root: str,
|
||||
max_prefetch_workers: int = 4,
|
||||
):
|
||||
if max_bytes <= 0:
|
||||
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
|
||||
self.byte_index = byte_index
|
||||
self.max_bytes = max_bytes
|
||||
self.data_root = data_root.rstrip("/")
|
||||
self._bytes_used = 0
|
||||
self._lock = threading.Lock()
|
||||
self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict()
|
||||
self._header_cache: dict[int, bytes] = {}
|
||||
self._fetcher_cache: dict[int, RangeFetcher] = {}
|
||||
self._episodes: dict[int, _EpisodeEntry] = {}
|
||||
self._stats = CacheStats()
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
|
||||
self._futures: dict[int, Future] = {}
|
||||
|
||||
@property
|
||||
def stats(self) -> CacheStats:
|
||||
with self._lock:
|
||||
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
|
||||
|
||||
def submit_prefetch(self, ep_idx: int) -> None:
|
||||
with self._lock:
|
||||
if ep_idx in self._episodes or ep_idx in self._futures:
|
||||
return
|
||||
self._stats.prefetch_submitted += 1
|
||||
fut = self._executor.submit(self._prefetch_episode, ep_idx)
|
||||
self._futures[ep_idx] = fut
|
||||
|
||||
def ensure_ready(self, ep_idx: int) -> None:
|
||||
with self._lock:
|
||||
fut = self._futures.pop(ep_idx, None)
|
||||
if fut is not None:
|
||||
with self._lock:
|
||||
self._stats.prefetch_waits += 1
|
||||
fut.result()
|
||||
entry = self._episodes.get(ep_idx)
|
||||
if entry is None:
|
||||
raise KeyError(f"episode {ep_idx} not prefetched")
|
||||
if entry.error is not None:
|
||||
raise entry.error
|
||||
entry.ready.wait()
|
||||
|
||||
def get_decoder(self, ep_idx: int, video_key: str) -> Any:
|
||||
entry = self._episodes[ep_idx]
|
||||
if entry.error is not None:
|
||||
raise entry.error
|
||||
entry.ready.wait()
|
||||
return entry.decoders[video_key]
|
||||
|
||||
def close(self) -> None:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
def _prefetch_episode(self, ep_idx: int) -> None:
|
||||
entry = _EpisodeEntry()
|
||||
self._episodes[ep_idx] = entry
|
||||
try:
|
||||
for cam in self.byte_index.video_keys:
|
||||
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
|
||||
except Exception as exc:
|
||||
entry.error = exc
|
||||
finally:
|
||||
entry.ready.set()
|
||||
|
||||
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
|
||||
key = (ep_idx, cam)
|
||||
with self._lock:
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.hits += 1
|
||||
payload, _ = cached
|
||||
t0 = time.perf_counter()
|
||||
dec = self._decoder_from_payload(payload, ep_idx, cam)
|
||||
with self._lock:
|
||||
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
|
||||
return dec
|
||||
|
||||
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
|
||||
|
||||
with self._lock:
|
||||
self._stats.misses += 1
|
||||
if payload_bytes > self.MAX_BYTES_PER_MISS:
|
||||
logger.warning(
|
||||
"byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s",
|
||||
payload_bytes / 1e6,
|
||||
ep_idx,
|
||||
cam,
|
||||
)
|
||||
self._evict_until(payload_bytes)
|
||||
self._cache[key] = (payload, payload_bytes)
|
||||
self._bytes_used += payload_bytes
|
||||
return dec
|
||||
|
||||
def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]:
|
||||
lookup = self.byte_index.lookup(ep_idx, cam)
|
||||
file_info = self.byte_index.file_lookup(lookup.file_id)
|
||||
fetcher = self._get_fetcher(lookup.file_id, file_info.file_path)
|
||||
t_fetch = time.perf_counter()
|
||||
header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length)
|
||||
lo = lookup.mdat_offset
|
||||
hi = lo + lookup.mdat_length - 1
|
||||
mdat = fetcher.fetch(lo, hi)
|
||||
fetch_s = time.perf_counter() - t_fetch
|
||||
nbytes = len(header) + len(mdat)
|
||||
with self._lock:
|
||||
self._stats.bytes_fetched += nbytes
|
||||
self._stats.mdat_slices += 1
|
||||
self._stats.fetch_to_buffer_s += fetch_s
|
||||
|
||||
def lazy_fetch(pos: int, end: int) -> bytes:
|
||||
data = fetcher.fetch(pos, end)
|
||||
with self._lock:
|
||||
self._stats.bytes_fetched += len(data)
|
||||
return data
|
||||
|
||||
reader = SparseMp4Reader(
|
||||
file_size=file_info.file_size,
|
||||
header=header,
|
||||
mdat_lo=lo,
|
||||
mdat_bytes=mdat,
|
||||
lazy_fetch=lazy_fetch,
|
||||
)
|
||||
t_init = time.perf_counter()
|
||||
dec = self._decoder_from_payload(reader, ep_idx, cam)
|
||||
self._validate_decoder(dec, lookup)
|
||||
init_s = time.perf_counter() - t_init
|
||||
with self._lock:
|
||||
self._stats.buffer_to_decoder_s += init_s
|
||||
self._rewind_payload(reader)
|
||||
return reader, nbytes, dec
|
||||
|
||||
def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher:
|
||||
if file_id not in self._fetcher_cache:
|
||||
path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}"
|
||||
self._fetcher_cache[file_id] = RangeFetcher(path)
|
||||
return self._fetcher_cache[file_id]
|
||||
|
||||
def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes:
|
||||
if file_id in self._header_cache:
|
||||
return self._header_cache[file_id]
|
||||
hi = max(0, header_length - 1)
|
||||
header = fetcher.fetch(0, hi)
|
||||
with self._lock:
|
||||
self._header_cache[file_id] = header
|
||||
self._stats.bytes_fetched += len(header)
|
||||
return header
|
||||
|
||||
def _decoder_from_payload(
|
||||
self, payload: SparseMp4Reader, ep_idx: int, cam: str
|
||||
) -> Any:
|
||||
payload.seek(0)
|
||||
mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
|
||||
return open_video_decoder(payload, frame_mappings=mappings)
|
||||
|
||||
def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None:
|
||||
begin = float(dec.metadata.begin_stream_seconds)
|
||||
end = float(dec.metadata.end_stream_seconds)
|
||||
duration = max(0.01, end - begin)
|
||||
for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
|
||||
dec.get_frames_played_at([ts]).data
|
||||
|
||||
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
|
||||
payload.seek(0)
|
||||
|
||||
def _evict_until(self, need: int) -> None:
|
||||
while self._bytes_used + need > self.max_bytes and self._cache:
|
||||
_, (_, size) = self._cache.popitem(last=False)
|
||||
self._bytes_used -= size
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -1,555 +0,0 @@
|
||||
"""MP4 moov parsing and tight per-episode mdat byte-range fetching.
|
||||
|
||||
LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head).
|
||||
For streaming we fetch only the file header plus the episode's contiguous mdat span
|
||||
instead of the ``0..episode_end`` prefix.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
KEYFRAME_PAD_S = 0.1
|
||||
HEADER_PROBE_BYTES = 4 * 1024 * 1024
|
||||
MAX_HEADER_PROBE_BYTES = 16 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mp4FileLayout:
|
||||
file_size: int
|
||||
moov_offset: int
|
||||
moov_length: int
|
||||
header_end: int
|
||||
mdat_offset: int
|
||||
mdat_size: int
|
||||
faststart: bool
|
||||
codec: str
|
||||
|
||||
|
||||
def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout:
|
||||
"""Return top-level MP4 layout (moov/mdat positions, faststart flag)."""
|
||||
boxes = list(_iter_boxes(header_bytes))
|
||||
moov_offset = mdat_offset = -1
|
||||
moov_length = mdat_size = 0
|
||||
for off, size, typ, _ in boxes:
|
||||
if typ == b"moov" and moov_offset < 0:
|
||||
moov_offset, moov_length = off, size
|
||||
if typ == b"mdat" and mdat_offset < 0:
|
||||
mdat_offset, mdat_size = off, size
|
||||
if moov_offset < 0:
|
||||
raise ValueError("moov box not found in header probe")
|
||||
if mdat_offset < 0:
|
||||
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
|
||||
faststart = moov_offset < mdat_offset
|
||||
header_end = mdat_offset
|
||||
codec = _parse_video_codec(header_bytes)
|
||||
return Mp4FileLayout(
|
||||
file_size=file_size,
|
||||
moov_offset=moov_offset,
|
||||
moov_length=moov_length,
|
||||
header_end=header_end,
|
||||
mdat_offset=mdat_offset,
|
||||
mdat_size=mdat_size,
|
||||
faststart=faststart,
|
||||
codec=codec,
|
||||
)
|
||||
|
||||
|
||||
def _parse_video_codec(header_bytes: bytes) -> str:
|
||||
moov = _find_box_payload(header_bytes, b"moov")
|
||||
if moov is None:
|
||||
return "unknown"
|
||||
trak = _find_video_trak(moov)
|
||||
if trak is None:
|
||||
return "unknown"
|
||||
stsd = _find_box_payload(_find_box_payload(trak, b"stbl") or b"", b"stsd")
|
||||
if stsd is None or len(stsd) < 12:
|
||||
return "unknown"
|
||||
# stsd: version(1)+flags(3)+entry_count(4)+entry_size(4)+codec(4)
|
||||
if len(stsd) >= 12:
|
||||
return stsd[8:12].decode("latin1", errors="replace").strip("\x00")
|
||||
return "unknown"
|
||||
|
||||
|
||||
def average_fps_from_index(index: Mp4VideoIndex) -> float:
|
||||
index.ensure_tables()
|
||||
if index.num_samples < 2:
|
||||
return 30.0
|
||||
duration = index.sample_pts(index.num_samples - 1)
|
||||
if duration <= 0:
|
||||
return 30.0
|
||||
return index.num_samples / duration
|
||||
|
||||
|
||||
def episode_custom_frame_mappings_json(
|
||||
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
|
||||
) -> bytes:
|
||||
"""Build TorchCodec ``custom_frame_mappings`` JSON for one episode span."""
|
||||
import json
|
||||
|
||||
index.ensure_tables()
|
||||
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
|
||||
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
|
||||
hi_idx = min(hi_idx, index.num_samples - 1)
|
||||
lo_idx = _keyframe_back(index.sync_samples, lo_idx)
|
||||
|
||||
sync = set(index.sync_samples)
|
||||
timescale = index.timescale
|
||||
# stts deltas for duration per sample (expand stts entries to per-sample delta)
|
||||
sample_deltas: list[int] = []
|
||||
for count, delta in index.stts:
|
||||
sample_deltas.extend([delta] * count)
|
||||
while len(sample_deltas) < index.num_samples:
|
||||
sample_deltas.append(sample_deltas[-1] if sample_deltas else timescale // 30)
|
||||
|
||||
frames = []
|
||||
for idx in range(lo_idx, hi_idx + 1):
|
||||
frames.append(
|
||||
{
|
||||
"pts": int(round(index._pts[idx] * timescale)),
|
||||
"duration": int(sample_deltas[idx]),
|
||||
"key_frame": int((idx + 1) in sync) if sync else int(idx == lo_idx),
|
||||
}
|
||||
)
|
||||
return json.dumps({"frames": frames}).encode()
|
||||
|
||||
|
||||
def episode_keyframes(
|
||||
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
|
||||
) -> list[tuple[float, int]]:
|
||||
"""Return (pts_seconds, byte_offset) for sync samples in the episode span."""
|
||||
index.ensure_tables()
|
||||
span = index.episode_byte_span(from_ts, to_ts, keyframe_pad_s)
|
||||
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
|
||||
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
|
||||
if not index.sync_samples:
|
||||
return [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
|
||||
out: list[tuple[float, int]] = []
|
||||
for sync_one_based in index.sync_samples:
|
||||
idx = sync_one_based - 1
|
||||
if lo_idx <= idx <= hi_idx:
|
||||
out.append((index.sample_pts(idx), index.sample_offset(idx)))
|
||||
return out or [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeByteSpan:
|
||||
"""Absolute file byte ranges to fetch for one episode."""
|
||||
|
||||
file_size: int
|
||||
header_end: int
|
||||
slice_lo: int
|
||||
slice_hi: int
|
||||
|
||||
@property
|
||||
def header_bytes(self) -> tuple[int, int]:
|
||||
return 0, self.header_end - 1
|
||||
|
||||
@property
|
||||
def mdat_bytes(self) -> tuple[int, int]:
|
||||
return self.slice_lo, self.slice_hi
|
||||
|
||||
@property
|
||||
def total_fetch_bytes(self) -> int:
|
||||
header = self.header_end
|
||||
mdat = self.slice_hi - self.slice_lo + 1
|
||||
return header + mdat
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mp4VideoIndex:
|
||||
file_size: int
|
||||
header_end: int
|
||||
mdat_offset: int
|
||||
mdat_size: int
|
||||
timescale: int
|
||||
stts: list[tuple[int, int]]
|
||||
stsz: list[int]
|
||||
stsc: list[tuple[int, int, int]]
|
||||
stco: list[int]
|
||||
sync_samples: list[int]
|
||||
_pts: list[float] = field(default_factory=list, repr=False)
|
||||
_offsets: list[int] = field(default_factory=list, repr=False)
|
||||
|
||||
def ensure_tables(self) -> None:
|
||||
if self._pts:
|
||||
return
|
||||
self._pts = _pts_from_stts(self.stts, self.timescale)
|
||||
self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.stsz)
|
||||
|
||||
def sample_pts(self, index: int) -> float:
|
||||
self.ensure_tables()
|
||||
return self._pts[index]
|
||||
|
||||
def sample_offset(self, index: int) -> int:
|
||||
self.ensure_tables()
|
||||
index = max(0, min(index, len(self._offsets) - 1))
|
||||
return self._offsets[index]
|
||||
|
||||
def sample_end(self, index: int) -> int:
|
||||
return self.sample_offset(index) + self.stsz[index]
|
||||
|
||||
def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan:
|
||||
self.ensure_tables()
|
||||
n = self.num_samples
|
||||
if n == 0:
|
||||
raise ValueError("MP4 has no video samples")
|
||||
|
||||
pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts))
|
||||
lo_ts = max(0.0, from_ts - pad)
|
||||
hi_ts = to_ts + pad
|
||||
|
||||
lo_idx = _first_sample_at_or_after(self._pts, lo_ts)
|
||||
hi_idx = _last_sample_at_or_before(self._pts, hi_ts)
|
||||
hi_idx = min(hi_idx, n - 1)
|
||||
lo_idx = min(lo_idx, n - 1)
|
||||
|
||||
lo_idx = _keyframe_back(self.sync_samples, lo_idx)
|
||||
|
||||
slice_lo = self.sample_offset(lo_idx)
|
||||
slice_hi = self.sample_end(min(hi_idx, len(self._offsets) - 1))
|
||||
return EpisodeByteSpan(
|
||||
file_size=self.file_size,
|
||||
header_end=self.header_end,
|
||||
slice_lo=slice_lo,
|
||||
slice_hi=min(slice_hi, self.file_size - 1),
|
||||
)
|
||||
|
||||
|
||||
class SparseMp4Reader(io.BufferedIOBase):
|
||||
"""Range-backed MP4 reader: header + one mdat span at absolute offsets."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_size: int,
|
||||
header: bytes,
|
||||
mdat_lo: int,
|
||||
mdat_bytes: bytes,
|
||||
lazy_fetch: Callable[[int, int], bytes] | None = None,
|
||||
):
|
||||
self._size = file_size
|
||||
self._header = header
|
||||
self._mdat_lo = mdat_lo
|
||||
self._mdat_hi = mdat_lo + len(mdat_bytes)
|
||||
self._mdat = mdat_bytes
|
||||
self._lazy_fetch = lazy_fetch
|
||||
self._pos = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return True
|
||||
|
||||
def tell(self) -> int:
|
||||
return self._pos
|
||||
|
||||
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
|
||||
if whence == io.SEEK_SET:
|
||||
self._pos = offset
|
||||
elif whence == io.SEEK_CUR:
|
||||
self._pos += offset
|
||||
elif whence == io.SEEK_END:
|
||||
self._pos = self._size + offset
|
||||
else:
|
||||
raise ValueError(f"invalid whence: {whence}")
|
||||
self._pos = max(0, min(self._pos, self._size))
|
||||
return self._pos
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if size < 0:
|
||||
size = self._size - self._pos
|
||||
if size <= 0:
|
||||
return b""
|
||||
|
||||
out = bytearray()
|
||||
remaining = size
|
||||
pos = self._pos
|
||||
while remaining > 0 and pos < self._size:
|
||||
chunk = self._read_at(pos, remaining)
|
||||
if not chunk:
|
||||
break
|
||||
out.extend(chunk)
|
||||
pos += len(chunk)
|
||||
remaining -= len(chunk)
|
||||
self._pos = pos
|
||||
return bytes(out)
|
||||
|
||||
def _read_at(self, pos: int, n: int) -> bytes:
|
||||
header_len = len(self._header)
|
||||
if pos < header_len:
|
||||
end = min(pos + n, header_len)
|
||||
return self._header[pos:end]
|
||||
|
||||
if self._mdat_lo <= pos < self._mdat_hi:
|
||||
end = min(pos + n, self._mdat_hi)
|
||||
off = pos - self._mdat_lo
|
||||
return self._mdat[off : off + (end - pos)]
|
||||
|
||||
if self._lazy_fetch is not None:
|
||||
with self._lock:
|
||||
end = min(pos + n, self._size)
|
||||
return self._lazy_fetch(pos, end - 1)
|
||||
|
||||
return b"\x00" * min(n, self._size - pos)
|
||||
|
||||
|
||||
def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex:
|
||||
"""Parse moov sample tables from the file header (faststart layout)."""
|
||||
layout = parse_mp4_file_layout(header_bytes, file_size)
|
||||
mdat_offset, mdat_size = layout.mdat_offset, layout.mdat_size
|
||||
moov = _find_box_payload(header_bytes, b"moov")
|
||||
if moov is None:
|
||||
raise ValueError("moov box not found in MP4 header probe")
|
||||
|
||||
trak = _find_video_trak(moov)
|
||||
if trak is None:
|
||||
raise ValueError("video trak not found in moov")
|
||||
|
||||
mdhd = _find_box_payload(trak, b"mdhd")
|
||||
if mdhd is None:
|
||||
raise ValueError("mdhd not found")
|
||||
timescale = _parse_mdhd_timescale(mdhd)
|
||||
|
||||
stbl = _find_box_payload(trak, b"stbl")
|
||||
if stbl is None:
|
||||
raise ValueError("stbl not found")
|
||||
|
||||
stts = _parse_stts(_find_box_payload(stbl, b"stts"))
|
||||
stsz = _parse_stsz(_find_box_payload(stbl, b"stsz"))
|
||||
stsc = _parse_stsc(_find_box_payload(stbl, b"stsc"))
|
||||
stco_payload = _find_box_payload(stbl, b"stco")
|
||||
co64_payload = _find_box_payload(stbl, b"co64")
|
||||
if stco_payload is not None:
|
||||
stco = _parse_stco(stco_payload)
|
||||
elif co64_payload is not None:
|
||||
stco = _parse_co64(co64_payload)
|
||||
else:
|
||||
raise ValueError("stco/co64 not found")
|
||||
|
||||
stss_payload = _find_box_payload(stbl, b"stss")
|
||||
sync_samples = _parse_stss(stss_payload) if stss_payload else []
|
||||
|
||||
return Mp4VideoIndex(
|
||||
file_size=file_size,
|
||||
header_end=layout.header_end,
|
||||
mdat_offset=mdat_offset,
|
||||
mdat_size=mdat_size,
|
||||
timescale=timescale,
|
||||
stts=stts,
|
||||
stsz=stsz,
|
||||
stsc=stsc,
|
||||
stco=stco,
|
||||
sync_samples=sync_samples,
|
||||
)
|
||||
|
||||
|
||||
def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None:
|
||||
if offset + 8 > len(data):
|
||||
return None
|
||||
size, typ = struct.unpack_from(">I4s", data, offset)
|
||||
header = 8
|
||||
if size == 1:
|
||||
if offset + 16 > len(data):
|
||||
return None
|
||||
size = struct.unpack_from(">Q", data, offset + 8)[0]
|
||||
header = 16
|
||||
elif size == 0:
|
||||
size = len(data) - offset
|
||||
return size, typ, header
|
||||
|
||||
|
||||
def _iter_boxes(data: bytes, start: int = 0, end: int | None = None):
|
||||
end = end if end is not None else len(data)
|
||||
off = start
|
||||
while off + 8 <= end:
|
||||
hdr = _box_header(data, off)
|
||||
if hdr is None or hdr[0] < hdr[2]:
|
||||
break
|
||||
size, typ, header = hdr
|
||||
yield off, size, typ, data[off + header : off + size]
|
||||
off += size
|
||||
|
||||
|
||||
def _find_box_payload(data: bytes, target: bytes) -> bytes | None:
|
||||
for _, _, typ, payload in _iter_boxes(data):
|
||||
if typ == target:
|
||||
return payload
|
||||
if typ in (b"moov", b"trak", b"mdia", b"minf", b"stbl"):
|
||||
found = _find_box_payload(payload, target)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> bytes | None:
|
||||
for _, _, typ, payload in _iter_boxes(moov):
|
||||
if typ != b"trak":
|
||||
continue
|
||||
hdlr = _find_box_payload(payload, b"hdlr")
|
||||
if hdlr is not None and len(hdlr) >= 12 and hdlr[8:12] == b"vide":
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]:
|
||||
for off, size, typ, _ in _iter_boxes(header_bytes):
|
||||
if typ == b"mdat":
|
||||
return off, size
|
||||
# mdat may start beyond probe; scan from file_size hint unavailable — require probe hit
|
||||
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
|
||||
|
||||
|
||||
def _parse_mdhd_timescale(mdhd: bytes) -> int:
|
||||
version = mdhd[0]
|
||||
if version == 0:
|
||||
return struct.unpack_from(">I", mdhd, 12)[0]
|
||||
return struct.unpack_from(">I", mdhd, 20)[0]
|
||||
|
||||
|
||||
def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]:
|
||||
if stts is None:
|
||||
raise ValueError("stts missing")
|
||||
count = struct.unpack_from(">I", stts, 4)[0]
|
||||
out = []
|
||||
off = 8
|
||||
for _ in range(count):
|
||||
sample_count, delta = struct.unpack_from(">II", stts, off)
|
||||
out.append((sample_count, delta))
|
||||
off += 8
|
||||
return out
|
||||
|
||||
|
||||
def _parse_stsz(stsz: bytes | None) -> list[int]:
|
||||
if stsz is None:
|
||||
raise ValueError("stsz missing")
|
||||
sample_size, sample_count = struct.unpack_from(">II", stsz, 4)
|
||||
if sample_size != 0:
|
||||
return [sample_size] * sample_count
|
||||
off = 12
|
||||
return list(struct.unpack_from(f">{sample_count}I", stsz, off))
|
||||
|
||||
|
||||
def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]:
|
||||
if stsc is None:
|
||||
raise ValueError("stsc missing")
|
||||
count = struct.unpack_from(">I", stsc, 4)[0]
|
||||
out = []
|
||||
off = 8
|
||||
for _ in range(count):
|
||||
first_chunk, samples_per_chunk, sample_desc = struct.unpack_from(">III", stsc, off)
|
||||
out.append((first_chunk, samples_per_chunk, sample_desc))
|
||||
off += 12
|
||||
return out
|
||||
|
||||
|
||||
def _parse_stco(stco: bytes) -> list[int]:
|
||||
count = struct.unpack_from(">I", stco, 4)[0]
|
||||
return list(struct.unpack_from(f">{count}I", stco, 8))
|
||||
|
||||
|
||||
def _parse_co64(co64: bytes) -> list[int]:
|
||||
count = struct.unpack_from(">I", co64, 4)[0]
|
||||
return [struct.unpack_from(">Q", co64, 8 + i * 8)[0] for i in range(count)]
|
||||
|
||||
|
||||
def _parse_stss(stss: bytes) -> list[int]:
|
||||
count = struct.unpack_from(">I", stss, 4)[0]
|
||||
return list(struct.unpack_from(f">{count}I", stss, 8))
|
||||
|
||||
|
||||
def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
|
||||
pts: list[float] = []
|
||||
t = 0
|
||||
for count, delta in stts:
|
||||
for _ in range(count):
|
||||
pts.append(t / timescale)
|
||||
t += delta
|
||||
return pts
|
||||
|
||||
|
||||
def _sample_byte_offsets(
|
||||
stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
|
||||
) -> list[int]:
|
||||
if not stsc:
|
||||
stsc = [(1, len(stsz), 1)]
|
||||
|
||||
offsets: list[int] = []
|
||||
chunk_idx = 0
|
||||
sample_idx = 0
|
||||
sc_idx = 0
|
||||
num_chunks = len(stco)
|
||||
|
||||
while chunk_idx < num_chunks and sample_idx < len(stsz):
|
||||
first_chunk, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
|
||||
if sc_idx + 1 < len(stsc):
|
||||
next_first = stsc[sc_idx + 1][0]
|
||||
chunks_in_entry = next_first - first_chunk
|
||||
else:
|
||||
chunks_in_entry = num_chunks - chunk_idx
|
||||
|
||||
for _ in range(chunks_in_entry):
|
||||
if chunk_idx >= num_chunks:
|
||||
break
|
||||
offset = stco[chunk_idx]
|
||||
_, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
|
||||
for _ in range(samples_per_chunk):
|
||||
if sample_idx >= len(stsz):
|
||||
break
|
||||
offsets.append(offset)
|
||||
offset += stsz[sample_idx]
|
||||
sample_idx += 1
|
||||
chunk_idx += 1
|
||||
sc_idx += 1
|
||||
|
||||
if len(offsets) < len(stsz):
|
||||
# Pad with last known offset progression for malformed stsc edge cases.
|
||||
last = offsets[-1] if offsets else 0
|
||||
while len(offsets) < len(stsz):
|
||||
idx = len(offsets)
|
||||
offsets.append(last)
|
||||
last += stsz[idx]
|
||||
|
||||
return offsets
|
||||
|
||||
|
||||
def _first_sample_at_or_after(pts: list[float], ts: float) -> int:
|
||||
lo, hi = 0, len(pts)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if pts[mid] < ts:
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid
|
||||
return min(lo, len(pts) - 1)
|
||||
|
||||
|
||||
def _last_sample_at_or_before(pts: list[float], ts: float) -> int:
|
||||
lo, hi = 0, len(pts)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if pts[mid] <= ts:
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid
|
||||
return max(0, lo - 1)
|
||||
|
||||
|
||||
def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
|
||||
if not sync_samples:
|
||||
return max(0, sample_idx - 2)
|
||||
# stss stores 1-based sample numbers
|
||||
one_based = sample_idx + 1
|
||||
prev = [s for s in sync_samples if s <= one_based]
|
||||
if prev:
|
||||
return prev[-1] - 1
|
||||
return 0
|
||||
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,49 +0,0 @@
|
||||
"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torchcodec import FrameBatch, _core as core
|
||||
from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata
|
||||
|
||||
|
||||
def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
data = json.loads(payload)
|
||||
frames = data["frames"]
|
||||
pts = torch.tensor([int(f["pts"]) for f in frames], dtype=torch.int64)
|
||||
key = torch.tensor([bool(f["key_frame"]) for f in frames], dtype=torch.bool)
|
||||
dur = torch.tensor([int(f["duration"]) for f in frames], dtype=torch.int64)
|
||||
return pts, key, dur
|
||||
|
||||
|
||||
class VideoDecoderLike:
|
||||
"""Minimal VideoDecoder surface used by episode byte cache."""
|
||||
|
||||
def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None):
|
||||
self._decoder = decoder
|
||||
(
|
||||
self.metadata,
|
||||
self.stream_index,
|
||||
self._begin_stream_seconds,
|
||||
self._end_stream_seconds,
|
||||
self._num_frames,
|
||||
) = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index)
|
||||
|
||||
def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
|
||||
return FrameBatch(*core.get_frames_by_pts(self._decoder, timestamps=seconds))
|
||||
|
||||
|
||||
def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike:
|
||||
"""Open a decoder on sparse or full MP4 IO, skipping metadata scan when mappings exist."""
|
||||
if frame_mappings is None:
|
||||
decoder = core.create_from_file_like(source, "approximate")
|
||||
core.add_video_stream(decoder)
|
||||
return VideoDecoderLike(decoder)
|
||||
|
||||
mappings = frame_mappings_tensors(frame_mappings)
|
||||
decoder = core.create_from_file_like(source, "custom_frame_mappings")
|
||||
core.add_video_stream(decoder, custom_frame_mappings=mappings)
|
||||
return VideoDecoderLike(decoder)
|
||||
@@ -273,11 +273,7 @@ class VideoDecoderCache:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
|
||||
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
|
||||
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
|
||||
# mostly-sequential reads torchcodec issues.
|
||||
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
@@ -326,7 +322,6 @@ def decode_video_frames_torchcodec(
|
||||
log_loaded_timestamps: bool = False,
|
||||
decoder_cache: VideoDecoderCache | None = None,
|
||||
return_uint8: bool = False,
|
||||
episode_decoder: Any | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
@@ -348,10 +343,8 @@ def decode_video_frames_torchcodec(
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
if episode_decoder is not None:
|
||||
decoder = episode_decoder
|
||||
else:
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
|
||||
loaded_ts = []
|
||||
loaded_frames = []
|
||||
|
||||
@@ -18,6 +18,7 @@ from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .fastwam.configuration_fastwam import FastWAMConfig as FastWAMConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
@@ -42,6 +43,7 @@ __all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"EO1Config",
|
||||
"FastWAMConfig",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
|
||||
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .fastwam.configuration_fastwam import FastWAMConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
elif name == "fastwam":
|
||||
from .fastwam.modeling_fastwam import FastWAMPolicy
|
||||
|
||||
return FastWAMPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
elif policy_type == "fastwam":
|
||||
return FastWAMConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -448,6 +455,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, FastWAMConfig):
|
||||
from .fastwam.processor_fastwam import make_fastwam_pre_post_processors
|
||||
|
||||
processors = make_fastwam_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_fastwam_README.md
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
from .modeling_fastwam import FastWAMPolicy
|
||||
from .processor_fastwam import make_fastwam_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"FastWAMConfig",
|
||||
"FastWAMPolicy",
|
||||
"make_fastwam_pre_post_processors",
|
||||
]
|
||||
@@ -0,0 +1,394 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
PolicyFeature,
|
||||
PreTrainedConfig,
|
||||
)
|
||||
from lerobot.optim import AdamWConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
|
||||
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base"
|
||||
|
||||
|
||||
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
|
||||
"patch_size",
|
||||
"in_dim",
|
||||
"hidden_dim",
|
||||
"ffn_dim",
|
||||
"freq_dim",
|
||||
"text_dim",
|
||||
"out_dim",
|
||||
"num_heads",
|
||||
"attn_head_dim",
|
||||
"num_layers",
|
||||
)
|
||||
|
||||
_FASTWAM_ACTION_BASE_COMPAT_KEYS = (
|
||||
"hidden_dim",
|
||||
"ffn_dim",
|
||||
"num_heads",
|
||||
"attn_head_dim",
|
||||
"num_layers",
|
||||
"text_dim",
|
||||
"freq_dim",
|
||||
)
|
||||
|
||||
|
||||
def default_video_dit_config(action_dim: int) -> dict[str, Any]:
|
||||
return {
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"hidden_dim": 3072,
|
||||
"ffn_dim": 14336,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 48,
|
||||
"num_heads": 24,
|
||||
"attn_head_dim": 128,
|
||||
"num_layers": 30,
|
||||
"eps": 1.0e-6,
|
||||
"separated_timestep": True,
|
||||
"use_gradient_checkpointing": False,
|
||||
"video_attention_mask_mode": "first_frame_causal",
|
||||
"action_conditioned": False,
|
||||
"action_dim": action_dim,
|
||||
"action_group_causal_mask_mode": "group_diagonal",
|
||||
"fp32_attention": True,
|
||||
}
|
||||
|
||||
|
||||
def default_action_dit_config(action_dim: int) -> dict[str, Any]:
|
||||
return {
|
||||
"action_dim": action_dim,
|
||||
"hidden_dim": 1024,
|
||||
"ffn_dim": 4096,
|
||||
"num_heads": 24,
|
||||
"attn_head_dim": 128,
|
||||
"num_layers": 30,
|
||||
"text_dim": 4096,
|
||||
"freq_dim": 256,
|
||||
"eps": 1.0e-6,
|
||||
"use_gradient_checkpointing": False,
|
||||
"fp32_attention": True,
|
||||
}
|
||||
|
||||
|
||||
def _coerce_enum(enum_cls: type, value: Any) -> Any:
|
||||
if isinstance(value, enum_cls):
|
||||
return value
|
||||
try:
|
||||
return enum_cls(value)
|
||||
except (TypeError, ValueError):
|
||||
return getattr(enum_cls, str(value), value)
|
||||
|
||||
|
||||
def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None:
|
||||
if features is None:
|
||||
return None
|
||||
coerced = {}
|
||||
for name, feature in features.items():
|
||||
if isinstance(feature, PolicyFeature):
|
||||
coerced[name] = feature
|
||||
continue
|
||||
coerced[name] = PolicyFeature(
|
||||
type=_coerce_enum(FeatureType, feature["type"]),
|
||||
shape=tuple(feature["shape"]),
|
||||
)
|
||||
return coerced
|
||||
|
||||
|
||||
def _is_local_model_id(value: str) -> bool:
|
||||
path = Path(value).expanduser()
|
||||
return path.is_absolute() or value.startswith(("./", "../", "~")) or path.exists()
|
||||
|
||||
|
||||
def _validate_wan_model_id(value: str, field_name: str) -> str:
|
||||
if value == WAN22_MODEL_ID or _is_local_model_id(value):
|
||||
return value
|
||||
raise ValueError(f"`{field_name}` must be `{WAN22_MODEL_ID}` or an explicit local path, got `{value}`.")
|
||||
|
||||
|
||||
def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool:
|
||||
"""Return whether `fastwam-base` partial weights can initialize this config."""
|
||||
|
||||
default_video_config = default_video_dit_config(config.action_dim)
|
||||
default_action_config = default_action_dit_config(config.action_dim)
|
||||
return all(
|
||||
config.video_dit_config.get(key) == default_video_config.get(key)
|
||||
for key in _FASTWAM_VIDEO_BASE_COMPAT_KEYS
|
||||
) and all(
|
||||
config.action_dit_config.get(key) == default_action_config.get(key)
|
||||
for key in _FASTWAM_ACTION_BASE_COMPAT_KEYS
|
||||
)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("fastwam")
|
||||
@dataclass
|
||||
class FastWAMConfig(PreTrainedConfig):
|
||||
"""Configuration for the FastWAM LeRobot policy.
|
||||
|
||||
Args:
|
||||
action_dim (int): Number of scalar action channels per timestep.
|
||||
proprio_dim (int | None): Number of proprioception channels used as an
|
||||
extra text-context token. `None` disables proprio conditioning.
|
||||
action_horizon (int): Number of actions predicted by one policy call.
|
||||
num_video_frames (int): Raw video sampling window (in dataset frames). The
|
||||
model actually operates on `model_video_frames` frames after subsampling
|
||||
by `action_video_freq_ratio`.
|
||||
action_video_freq_ratio (int): Actions are sampled at this multiple of the
|
||||
video frame rate. Video frames are taken every `action_video_freq_ratio`-th
|
||||
raw frame, so the model sees `(num_video_frames - 1) // ratio + 1` frames
|
||||
spanning the same time window as `action_horizon` actions (ratio actions
|
||||
per video frame).
|
||||
image_size (tuple[int, int]): Concatenated image size as `(height, width)`.
|
||||
context_len (int): Maximum text embedding token length.
|
||||
video_dit_config (dict[str, Any] | None): Wan video expert config.
|
||||
action_dit_config (dict[str, Any] | None): Action expert config.
|
||||
use_gradient_checkpointing (bool): Enable activation checkpointing in both DiT
|
||||
experts (trades compute for memory; propagated into the DiT configs).
|
||||
freeze_video_expert (bool): Freeze the ~5B Wan video expert
|
||||
(`model.video_expert`) so only the action expert + proprio encoder train.
|
||||
Cuts the AdamW optimizer footprint substantially; the video expert keeps its
|
||||
pretrained weights. (If enabled, also set `loss.lambda_video=0` to skip the
|
||||
now-gradient-free video loss compute.)
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
action_dim: int = 7
|
||||
proprio_dim: int | None = 8
|
||||
action_horizon: int = 32
|
||||
n_action_steps: int = 32
|
||||
num_video_frames: int = 33
|
||||
action_video_freq_ratio: int = 4
|
||||
image_size: tuple[int, int] = (224, 448)
|
||||
context_len: int = 128
|
||||
model_id: str = WAN22_MODEL_ID
|
||||
tokenizer_model_id: str = WAN22_MODEL_ID
|
||||
base_model_id: str | None = FASTWAM_BASE_MODEL_ID
|
||||
tokenizer_max_len: int = 128
|
||||
load_text_encoder: bool = True
|
||||
mot_checkpoint_mixed_attn: bool = False
|
||||
torch_dtype: str = "bfloat16"
|
||||
prompt_template: str = (
|
||||
"A video recorded from a robot's point of view executing the following instruction: {task}"
|
||||
)
|
||||
num_inference_steps: int = 10
|
||||
inference_seed: int | None = 42
|
||||
rand_device: str = "cpu"
|
||||
text_cfg_scale: float = 1.0
|
||||
negative_prompt: str = ""
|
||||
sigma_shift: float | None = None
|
||||
tiled: bool = False
|
||||
fp32_attention: bool = True
|
||||
use_gradient_checkpointing: bool = False
|
||||
freeze_video_expert: bool = False
|
||||
toggle_action_dimensions: list[int] = field(default_factory=list)
|
||||
video_scheduler: dict[str, float | int] = field(
|
||||
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
|
||||
)
|
||||
action_scheduler: dict[str, float | int] = field(
|
||||
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
|
||||
)
|
||||
loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0})
|
||||
video_dit_config: dict[str, Any] | None = None
|
||||
action_dit_config: dict[str, Any] | None = None
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
input_features: dict[str, PolicyFeature] | None = None
|
||||
output_features: dict[str, PolicyFeature] | None = None
|
||||
optimizer_lr: float = 1.0e-4
|
||||
optimizer_weight_decay: float = 1.0e-2
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
self.image_size = tuple(self.image_size)
|
||||
self.model_id = _validate_wan_model_id(self.model_id, "model_id")
|
||||
self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id")
|
||||
self.input_features = _coerce_policy_features(self.input_features)
|
||||
self.output_features = _coerce_policy_features(self.output_features)
|
||||
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
|
||||
self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim)
|
||||
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
|
||||
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
|
||||
self.action_dit_config["fp32_attention"] = bool(self.fp32_attention)
|
||||
self.video_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
|
||||
self.action_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
|
||||
if self.input_features is None:
|
||||
height, width = self.image_size
|
||||
self.input_features = {
|
||||
"observation.images.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, height, width),
|
||||
)
|
||||
}
|
||||
if self.proprio_dim is not None:
|
||||
self.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.proprio_dim,),
|
||||
)
|
||||
if self.output_features is None:
|
||||
self.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))}
|
||||
self.validate_features()
|
||||
if self.pretrained_path or self.use_peft or not self.base_model_id:
|
||||
return
|
||||
if not is_fastwam_base_compatible_config(self):
|
||||
return
|
||||
self.pretrained_path = Path(self.base_model_id)
|
||||
self._auto_pretrained_path = True
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
if not getattr(self, "_auto_pretrained_path", False):
|
||||
super()._save_pretrained(save_directory)
|
||||
return
|
||||
|
||||
pretrained_path = self.pretrained_path
|
||||
self.pretrained_path = None
|
||||
try:
|
||||
super()._save_pretrained(save_directory)
|
||||
finally:
|
||||
self.pretrained_path = pretrained_path
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None:
|
||||
"""Rebuild visual input features from the dataset's real camera keys.
|
||||
|
||||
FastWAM's `__post_init__` installs a synthetic single-image default
|
||||
(`observation.images.image` at full `image_size` width). For datasets
|
||||
with one or more separately-named cameras (e.g. `observation.images.top`,
|
||||
`observation.images.wrist`), this hook — invoked by `make_policy` once the
|
||||
dataset metadata is known — replaces that default with the actual camera
|
||||
keys, each declared at the policy's native per-camera resolution
|
||||
(`image_size[0]` x `image_size[1] // num_cameras`). The accompanying
|
||||
resize step in `make_fastwam_pre_post_processors` resizes raw frames to
|
||||
match, so heterogeneous source resolutions (e.g. 480x640) are supported.
|
||||
"""
|
||||
image_keys = sorted(
|
||||
key
|
||||
for key, feature in dataset_features.items()
|
||||
if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image")
|
||||
)
|
||||
if not image_keys:
|
||||
return
|
||||
height, total_width = self.image_size
|
||||
per_cam_width = total_width // len(image_keys)
|
||||
new_inputs: dict[str, PolicyFeature] = {
|
||||
key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, per_cam_width))
|
||||
for key in image_keys
|
||||
}
|
||||
if self.proprio_dim is not None and OBS_STATE in dataset_features:
|
||||
new_inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,))
|
||||
self.input_features = new_inputs
|
||||
self.validate_features()
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.action_dim <= 0:
|
||||
raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.")
|
||||
if self.action_horizon <= 0:
|
||||
raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.")
|
||||
if self.n_action_steps > self.action_horizon:
|
||||
raise ValueError("`n_action_steps` cannot exceed `action_horizon`.")
|
||||
if self.action_video_freq_ratio <= 0:
|
||||
raise ValueError(
|
||||
f"`action_video_freq_ratio` must be positive, got {self.action_video_freq_ratio}."
|
||||
)
|
||||
# Video frames are subsampled by action_video_freq_ratio; the resulting model frame
|
||||
# count must satisfy T % 4 == 1 for the VAE temporal tokenization (mirrors the
|
||||
# original FastWAM dataset asserts).
|
||||
if (self.num_video_frames - 1) % self.action_video_freq_ratio != 0:
|
||||
raise ValueError(
|
||||
f"`num_video_frames - 1` ({self.num_video_frames - 1}) must be divisible by "
|
||||
f"`action_video_freq_ratio` ({self.action_video_freq_ratio})."
|
||||
)
|
||||
if ((self.num_video_frames - 1) // self.action_video_freq_ratio) % 4 != 0:
|
||||
raise ValueError(
|
||||
f"Subsampled video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio}) "
|
||||
"must be divisible by 4 for VAE tokenization (i.e. model_video_frames % 4 == 1)."
|
||||
)
|
||||
if self.action_horizon % ((self.num_video_frames - 1) // self.action_video_freq_ratio) != 0:
|
||||
raise ValueError(
|
||||
f"`action_horizon` ({self.action_horizon}) must be divisible by the number of "
|
||||
f"video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio})."
|
||||
)
|
||||
if not self.image_features:
|
||||
raise ValueError("FastWAM requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("FastWAM requires `action` in output_features.")
|
||||
action_shape = tuple(self.action_feature.shape)
|
||||
if action_shape != (self.action_dim,):
|
||||
raise ValueError(
|
||||
f"FastWAM action feature shape must be ({self.action_dim},), got {action_shape}."
|
||||
)
|
||||
if self.proprio_dim is not None:
|
||||
state_feature = self.robot_state_feature
|
||||
if state_feature is None:
|
||||
raise ValueError("FastWAM requires `observation.state` when `proprio_dim` is set.")
|
||||
state_shape = tuple(state_feature.shape)
|
||||
if state_shape != (self.proprio_dim,):
|
||||
raise ValueError(
|
||||
f"FastWAM state feature shape must be ({self.proprio_dim},), got {state_shape}."
|
||||
)
|
||||
height, width = self.image_size
|
||||
image_width_sum = 0
|
||||
for name, feature in self.image_features.items():
|
||||
shape = tuple(feature.shape)
|
||||
if len(shape) != 3 or shape[0] != 3:
|
||||
raise ValueError(f"FastWAM image feature `{name}` must have shape (3, H, W), got {shape}.")
|
||||
if shape[1] != height:
|
||||
raise ValueError(f"FastWAM image feature `{name}` height must be {height}, got {shape[1]}.")
|
||||
image_width_sum += shape[2]
|
||||
if image_width_sum != width:
|
||||
raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.")
|
||||
|
||||
@property
|
||||
def model_video_frames(self) -> int:
|
||||
"""Number of video frames the model actually operates on, after subsampling the
|
||||
raw `num_video_frames` window by `action_video_freq_ratio` (e.g. 33 -> 9)."""
|
||||
return (self.num_video_frames - 1) // self.action_video_freq_ratio + 1
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# Load the video frames the model is supervised on: the future window subsampled by
|
||||
# action_video_freq_ratio (e.g. [0, 4, 8, ..., 32] -> 9 frames). Each video frame is
|
||||
# thus `action_video_freq_ratio` actions apart, while actions load at the full rate
|
||||
# (`action_delta_indices` = range(action_horizon)). Returning None would load only the
|
||||
# current frame, making the video target a static repeat (degenerate supervision).
|
||||
return list(range(0, self.num_video_frames, self.action_video_freq_ratio))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.action_horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,540 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
from .modular_fastwam import ActionDiT, FastWAM, MoT
|
||||
from .wan_components import (
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
)
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
|
||||
# eval episode's action chunks through `infer_joint` so the predicted video latents are
|
||||
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
|
||||
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
|
||||
# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1
|
||||
# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/
|
||||
# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the
|
||||
# (now-fixed) bug where the model ran on the un-subsampled num_video_frames.
|
||||
_DEBUG_PRED_RATE_DIV = 1
|
||||
|
||||
|
||||
class FastWAMPolicy(PreTrainedPolicy):
|
||||
"""LeRobot policy wrapper for FastWAM.
|
||||
|
||||
Args:
|
||||
config (FastWAMConfig): FastWAM policy configuration.
|
||||
dataset_stats (dict[str, dict[str, Tensor]] | None): Optional LeRobot
|
||||
dataset statistics passed by the training/evaluation stack.
|
||||
"""
|
||||
|
||||
config_class = FastWAMConfig
|
||||
name = "fastwam"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FastWAMConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
|
||||
# dataset feature metadata is already applied to `config` by make_policy upstream,
|
||||
# so we accept and ignore them, matching the other LeRobot policies.
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.model = self._build_core_model(config)
|
||||
if config.freeze_video_expert and getattr(self.model, "video_expert", None) is not None:
|
||||
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
|
||||
# so its params drop out of the optimizer (and DDP skips them).
|
||||
self.model.video_expert.requires_grad_(False)
|
||||
# The transformer blocks are re-parented onto the MoTLayers (single FSDP owner), so
|
||||
# `video_expert.requires_grad_` no longer reaches them — freeze them via the layers.
|
||||
mot = getattr(self.model, "mot", None)
|
||||
if mot is not None and getattr(mot, "layers", None) is not None:
|
||||
for layer in mot.layers:
|
||||
if "video" in layer.blocks:
|
||||
layer.blocks["video"].requires_grad_(False)
|
||||
self.reset()
|
||||
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
|
||||
# counts only eval-rollout resets (one per episode), not this __init__ one.
|
||||
self._debug_constructed = True
|
||||
self._debug_episode_index = -1
|
||||
self._debug_seen_tasks: set[str] = set()
|
||||
self._debug_capturing = False
|
||||
self._debug_episode_started = False
|
||||
self._debug_episode_task = ""
|
||||
self._debug_step_in_chunk = 0
|
||||
self._debug_last_video: list | None = None
|
||||
self._debug_pairs: list = []
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
|
||||
"""Shape-aware load that supports cross-embodiment fine-tuning.
|
||||
|
||||
`safetensors.load_model(strict=False)` ignores missing/unexpected keys but
|
||||
still raises on a shape mismatch for a shared key. When fine-tuning from a
|
||||
checkpoint trained on a different embodiment (e.g. the LIBERO 7-DoF / 8-dim
|
||||
checkpoint adapted to a 6-DoF / 6-dim arm), the action encoder/head and
|
||||
proprio encoder legitimately differ in shape. With `strict=False` we drop
|
||||
only those shape-mismatched tensors — leaving them at their freshly
|
||||
initialized values — and load every compatible tensor. With `strict=True`
|
||||
the standard exact-match loader is used.
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
mismatched = []
|
||||
with safe_open(model_file, framework="pt") as f:
|
||||
checkpoint_keys = list(f.keys())
|
||||
for key in checkpoint_keys:
|
||||
if key in model_state_dict and tuple(model_state_dict[key].shape) != tuple(
|
||||
f.get_slice(key).get_shape()
|
||||
):
|
||||
mismatched.append(key)
|
||||
|
||||
if not mismatched:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
if strict:
|
||||
raise RuntimeError(
|
||||
f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under "
|
||||
f"strict=True: {mismatched}"
|
||||
)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
logging.warning(
|
||||
"FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping "
|
||||
"every compatible weight: %s",
|
||||
len(mismatched),
|
||||
mismatched,
|
||||
)
|
||||
state_dict = load_file(model_file, device="cpu")
|
||||
for key in mismatched:
|
||||
state_dict.pop(key, None)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if map_location and map_location != "cpu":
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> list[Tensor]:
|
||||
# Return the trainable tensors directly (a single param group). The optimizer
|
||||
# builder wraps these in a param group; returning a bare {"params": [...]} dict
|
||||
# instead would make `list(...)` yield the key string "params".
|
||||
params = (
|
||||
list(self.model.dit.parameters()) if hasattr(self.model, "dit") else list(self.model.parameters())
|
||||
)
|
||||
proprio_encoder = getattr(self.model, "proprio_encoder", None)
|
||||
if proprio_encoder is not None:
|
||||
params.extend(list(proprio_encoder.parameters()))
|
||||
return [p for p in params if p.requires_grad]
|
||||
|
||||
def reset(self) -> None:
|
||||
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
|
||||
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
|
||||
# true-vs-pred video if it was a captured one (pairs accumulate only while
|
||||
# capturing), then reset per-episode capture state.
|
||||
if getattr(self, "_debug_constructed", False):
|
||||
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
|
||||
self._save_debug_video()
|
||||
self._debug_episode_index += 1
|
||||
self._debug_capturing = False
|
||||
self._debug_episode_started = False
|
||||
self._debug_episode_task = ""
|
||||
self._debug_step_in_chunk = 0
|
||||
self._debug_last_video = None
|
||||
self._debug_pairs = []
|
||||
|
||||
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
|
||||
`FastWAM.build_inputs` consumes (`video`, `action`, `context`/`context_mask`,
|
||||
per-frame `proprio`).
|
||||
|
||||
The LeRobot training loop passes raw `observation.images.*`, a single-step
|
||||
`observation.state` `[B, D]`, `action`, and a language `task` string. We do
|
||||
only the translation `build_inputs` can't: stack the camera frames into a
|
||||
video, encode the prompt with the (frozen) text encoder (mirroring inference,
|
||||
so language-conditioned datasets need no precomputed context), and give proprio
|
||||
the per-frame axis `build_inputs` indexes. All shape/presence validation is
|
||||
left to `build_inputs`, the single authority on the contract.
|
||||
"""
|
||||
sample = dict(batch)
|
||||
if "video" not in sample:
|
||||
sample["video"] = _stack_video_from_images(batch, self.config)
|
||||
if "context" not in sample or "context_mask" not in sample:
|
||||
prompt = _prompt_from_batch(batch=batch, config=self.config)
|
||||
if prompt is None:
|
||||
raise KeyError(
|
||||
"FastWAM training requires a `task`/`prompt` to encode text context, "
|
||||
"or precomputed `context`/`context_mask` in the batch."
|
||||
)
|
||||
sample["context"], sample["context_mask"] = self.model.encode_prompt(prompt)
|
||||
if self.config.proprio_dim is not None and "proprio" not in sample:
|
||||
state = sample.get(OBS_STATE)
|
||||
if state is not None:
|
||||
# LeRobot gives a single-step state [B, D]; build_inputs expects
|
||||
# per-frame [B, T, D] and uses frame 0, so add a T=1 axis.
|
||||
sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state
|
||||
return sample
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Compute FastWAM training loss for a LeRobot batch.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Tensor]): Batch containing FastWAM-ready keys
|
||||
(`video`, `action`, `context`, `context_mask`) or LeRobot keys
|
||||
that can be adapted (`observation.images.*`, `observation.state`,
|
||||
`action`, `action_is_pad`).
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, dict[str, Any]]: The scalar loss to backprop, and a dict of
|
||||
logging metrics (e.g. `loss_video`, `loss_action`) — the `(loss, output_dict)`
|
||||
contract the LeRobot training loop expects.
|
||||
"""
|
||||
|
||||
sample = self._batch_to_training_sample(batch)
|
||||
loss, metrics = self.model.training_loss(sample)
|
||||
return loss, dict(metrics or {})
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor:
|
||||
"""Predict a chunk of actions from the current FastWAM observation.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Tensor]): Inference batch with `input_image` or
|
||||
image observation keys, plus `context/context_mask` or `prompt`.
|
||||
|
||||
Returns:
|
||||
Tensor: Action chunk with shape `[B, action_horizon, action_dim]`.
|
||||
"""
|
||||
|
||||
self.eval()
|
||||
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
|
||||
batch_size = _infer_kwargs_batch_size(infer_kwargs)
|
||||
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
|
||||
# run the joint video+action path so the predicted video is VAE-decoded; stash it
|
||||
# so select_action can pair each predicted frame with the real obs that follows.
|
||||
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
|
||||
out = self.model.infer_joint(
|
||||
**infer_kwargs,
|
||||
num_video_frames=self.config.model_video_frames,
|
||||
test_action_with_infer_action=False,
|
||||
)
|
||||
# The decoded rollout has model_video_frames frames spanning the full
|
||||
# action_horizon (action_video_freq_ratio actions per frame); the per-step
|
||||
# pairing indexes into it, so keep all frames.
|
||||
self._debug_last_video = out["video"]
|
||||
action = _action_from_model_output(out)
|
||||
elif batch_size == 1:
|
||||
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
|
||||
else:
|
||||
action = torch.cat(
|
||||
[
|
||||
_action_from_model_output(
|
||||
self.model.infer_action(
|
||||
**_slice_infer_kwargs(infer_kwargs, index=i, batch_size=batch_size)
|
||||
)
|
||||
)
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
return action.to(device=batch_device(batch), dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
|
||||
self.eval()
|
||||
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
|
||||
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
|
||||
# get the first episode of every task).
|
||||
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
|
||||
self._debug_episode_started = True
|
||||
task = self._debug_task_name(batch)
|
||||
if task not in self._debug_seen_tasks:
|
||||
self._debug_seen_tasks.add(task)
|
||||
self._debug_capturing = True
|
||||
self._debug_episode_task = task
|
||||
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
if capturing:
|
||||
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
|
||||
if capturing:
|
||||
self._debug_capture_pair(batch)
|
||||
self._debug_step_in_chunk += 1
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
|
||||
@staticmethod
|
||||
def _debug_task_name(batch: dict[str, Any]) -> str:
|
||||
task = batch.get("task")
|
||||
if isinstance(task, (list, tuple)):
|
||||
task = task[0] if task else None
|
||||
return str(task) if task else "no_task"
|
||||
|
||||
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
|
||||
video = getattr(self, "_debug_last_video", None)
|
||||
if not video:
|
||||
return
|
||||
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
|
||||
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
|
||||
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
|
||||
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
|
||||
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
|
||||
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
|
||||
idx = min(
|
||||
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
|
||||
len(video) - 1,
|
||||
)
|
||||
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
|
||||
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
|
||||
self._debug_pairs.append(pair)
|
||||
|
||||
@staticmethod
|
||||
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
|
||||
from PIL import ImageDraw
|
||||
|
||||
draw = ImageDraw.Draw(pair)
|
||||
draw.text((3, 3), "true", fill=(255, 255, 0))
|
||||
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
|
||||
|
||||
@staticmethod
|
||||
def _debug_tensor_to_pil(image: Tensor):
|
||||
from PIL import Image
|
||||
|
||||
arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
|
||||
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
|
||||
|
||||
@staticmethod
|
||||
def _debug_hstack(left, right):
|
||||
from PIL import Image
|
||||
|
||||
if right.height != left.height:
|
||||
right = right.resize((round(right.width * left.height / right.height), left.height))
|
||||
canvas = Image.new("RGB", (left.width + right.width, left.height))
|
||||
canvas.paste(left, (0, 0))
|
||||
canvas.paste(right, (left.width, 0))
|
||||
return canvas
|
||||
|
||||
def _save_debug_video(self) -> None:
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.io_utils import write_video
|
||||
|
||||
pairs = getattr(self, "_debug_pairs", None)
|
||||
if not pairs:
|
||||
return
|
||||
out_dir = Path("outputs/fastwam_debug")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
|
||||
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
|
||||
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
|
||||
write_video(path, frames, fps=30)
|
||||
logging.info(
|
||||
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
|
||||
)
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
|
||||
"""Build the FastWAM core for training / inference.
|
||||
|
||||
Only the trainable parts (the MoT DiT and the proprio encoder) are
|
||||
materialized empty here and then filled from the policy's
|
||||
`model.safetensors` by the base `from_pretrained`. The *frozen* Wan2.2 VAE
|
||||
and UMT5 text encoder are loaded with their real weights from the
|
||||
`Wan-AI/Wan2.2-TI2V-5B-Diffusers` repo (cached in the HF cache, shared
|
||||
across checkpoints) and are intentionally excluded from `model.safetensors`
|
||||
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
|
||||
"""
|
||||
dtype = _dtype_from_name(config.torch_dtype)
|
||||
device = config.device
|
||||
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
|
||||
action_expert = ActionDiT(**config.action_dit_config).to(device=device, dtype=dtype)
|
||||
mot = MoT(
|
||||
mixtures={"video": video_expert, "action": action_expert},
|
||||
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
|
||||
)
|
||||
text_encoder = (
|
||||
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
|
||||
if config.load_text_encoder
|
||||
else None
|
||||
)
|
||||
return FastWAM(
|
||||
video_expert=video_expert,
|
||||
action_expert=action_expert,
|
||||
mot=mot,
|
||||
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
|
||||
text_dim=int(config.video_dit_config["text_dim"]),
|
||||
proprio_dim=config.proprio_dim,
|
||||
device=device,
|
||||
torch_dtype=dtype,
|
||||
video_train_shift=float(config.video_scheduler["train_shift"]),
|
||||
video_infer_shift=float(config.video_scheduler["infer_shift"]),
|
||||
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
|
||||
action_train_shift=float(config.action_scheduler["train_shift"]),
|
||||
action_infer_shift=float(config.action_scheduler["infer_shift"]),
|
||||
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
|
||||
loss_lambda_video=float(config.loss["lambda_video"]),
|
||||
loss_lambda_action=float(config.loss["lambda_action"]),
|
||||
)
|
||||
|
||||
|
||||
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
|
||||
return {
|
||||
"prompt": _prompt_from_batch(batch=batch, config=config),
|
||||
"input_image": _input_image_from_batch(batch, config),
|
||||
"action_horizon": config.action_horizon,
|
||||
"proprio": batch.get("proprio", batch.get(OBS_STATE)),
|
||||
"context": batch.get("context"),
|
||||
"context_mask": batch.get("context_mask"),
|
||||
"negative_prompt": batch.get("negative_prompt", config.negative_prompt),
|
||||
"text_cfg_scale": float(batch.get("text_cfg_scale", config.text_cfg_scale)),
|
||||
"num_inference_steps": int(batch.get("num_inference_steps", config.num_inference_steps)),
|
||||
"sigma_shift": batch.get("sigma_shift", config.sigma_shift),
|
||||
"seed": batch.get("seed", config.inference_seed),
|
||||
"rand_device": batch.get("rand_device", config.rand_device),
|
||||
"tiled": bool(batch.get("tiled", config.tiled)),
|
||||
}
|
||||
|
||||
|
||||
def _prompt_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Any:
|
||||
prompt = batch.get("prompt")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
|
||||
task = batch.get("task")
|
||||
if task is None:
|
||||
return None
|
||||
if isinstance(task, str):
|
||||
return config.prompt_template.format(task=task)
|
||||
if isinstance(task, (list, tuple)):
|
||||
return [config.prompt_template.format(task=str(item)) for item in task]
|
||||
return config.prompt_template.format(task=str(task))
|
||||
|
||||
|
||||
def _action_from_model_output(output: Any) -> Tensor:
|
||||
action = output["action"] if isinstance(output, dict) else output
|
||||
if action.ndim == 2:
|
||||
action = action.unsqueeze(0)
|
||||
return action
|
||||
|
||||
|
||||
def _infer_kwargs_batch_size(infer_kwargs: dict[str, Any]) -> int:
|
||||
image = infer_kwargs["input_image"]
|
||||
if not isinstance(image, Tensor):
|
||||
raise TypeError(f"`input_image` must be a tensor, got {type(image).__name__}.")
|
||||
if image.ndim == 3:
|
||||
return 1
|
||||
if image.ndim == 4:
|
||||
return int(image.shape[0])
|
||||
raise ValueError(f"`input_image` must be [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
|
||||
|
||||
|
||||
def _slice_infer_kwargs(infer_kwargs: dict[str, Any], *, index: int, batch_size: int) -> dict[str, Any]:
|
||||
return {
|
||||
key: _slice_infer_value(value, index=index, batch_size=batch_size)
|
||||
for key, value in infer_kwargs.items()
|
||||
}
|
||||
|
||||
|
||||
def _slice_infer_value(value: Any, *, index: int, batch_size: int) -> Any:
|
||||
if isinstance(value, Tensor) and value.ndim > 0 and value.shape[0] == batch_size:
|
||||
return value[index : index + 1]
|
||||
if isinstance(value, (list, tuple)) and len(value) == batch_size:
|
||||
return value[index]
|
||||
return value
|
||||
|
||||
|
||||
def _dtype_from_name(name: str) -> torch.dtype:
|
||||
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
|
||||
if name not in dtype_map:
|
||||
raise ValueError(f"Unsupported torch dtype `{name}`.")
|
||||
return dtype_map[name]
|
||||
|
||||
|
||||
def batch_device(batch: dict[str, Any]) -> torch.device:
|
||||
for value in batch.values():
|
||||
if isinstance(value, Tensor):
|
||||
return value.device
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
|
||||
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
|
||||
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
|
||||
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
|
||||
if not image_keys:
|
||||
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
|
||||
images = [batch[key] for key in image_keys]
|
||||
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
|
||||
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
|
||||
if image.ndim == 4:
|
||||
# [B, C, H, W]: a single frame (e.g. the live eval observation) -> repeat across time.
|
||||
image = image.unsqueeze(2).repeat(1, 1, config.model_video_frames, 1, 1)
|
||||
elif image.ndim == 5:
|
||||
# [B, T, C, H, W]: temporal stack from delta-timestamp loading -> [B, C, T, H, W].
|
||||
image = image.permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
raise ValueError(f"Expected image batch [B,C,H,W] or temporal [B,T,C,H,W], got {tuple(image.shape)}.")
|
||||
return image
|
||||
|
||||
|
||||
def _input_image_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
|
||||
if "input_image" in batch:
|
||||
return _prepare_infer_image(batch["input_image"], config)
|
||||
video = batch.get("video")
|
||||
if video is None:
|
||||
video = _stack_video_from_images(batch, config)
|
||||
if video.ndim == 5:
|
||||
return _prepare_infer_image(video[:, :, 0], config)
|
||||
if video.ndim == 4:
|
||||
return _prepare_infer_image(video, config)
|
||||
raise ValueError(f"Cannot build input image from tensor with shape {tuple(video.shape)}.")
|
||||
|
||||
|
||||
def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
|
||||
|
||||
target_h, target_w = config.image_size
|
||||
if tuple(image.shape[-2:]) != (target_h, target_w):
|
||||
raise ValueError(
|
||||
"FastWAM policy expects preprocessed image tensors with shape "
|
||||
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
|
||||
"Run the FastWAM preprocessor before calling the policy."
|
||||
)
|
||||
return image
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,183 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
|
||||
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
|
||||
"""`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack.
|
||||
|
||||
FastWAM loads a per-camera video stack, so image observations arrive as
|
||||
``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a
|
||||
single leading batch dim (resize raises on 5-D input), so we flatten any leading
|
||||
dims into the batch, apply the base 4-D crop/resize, then restore the leading shape.
|
||||
Crop/resize params and feature-shape bookkeeping are inherited unchanged.
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T]) that share
|
||||
# the `observation.images.` prefix but are padding flags, not frames. The base crop/resize
|
||||
# matches on the `"image"` substring, so set these aside and restore them untouched rather
|
||||
# than letting it try to resize a mask.
|
||||
pad_keys = {key: value for key, value in observation.items() if "_is_pad" in key}
|
||||
leads: dict[str, tuple] = {}
|
||||
flat_input = {key: value for key, value in observation.items() if key not in pad_keys}
|
||||
for key, img in list(flat_input.items()):
|
||||
if "image" in key and torch.is_tensor(img) and img.ndim > 4:
|
||||
leads[key] = tuple(img.shape[:-3])
|
||||
flat_input[key] = img.reshape(-1, *img.shape[-3:])
|
||||
processed = super().observation(flat_input)
|
||||
out = dict(processed)
|
||||
for key, lead in leads.items():
|
||||
im = processed[key]
|
||||
out[key] = im.reshape(*lead, *im.shape[-3:])
|
||||
out.update(pad_keys)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
|
||||
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
|
||||
"""Apply FastWAM LIBERO toggle semantics to configured action dimensions."""
|
||||
|
||||
toggle_dimensions: list[int]
|
||||
|
||||
def action(self, action: PolicyAction) -> PolicyAction:
|
||||
if not self.toggle_dimensions:
|
||||
return action
|
||||
processed_action = action.clone()
|
||||
action_dim = int(processed_action.shape[-1])
|
||||
for dim in self.toggle_dimensions:
|
||||
resolved_dim = dim if dim >= 0 else action_dim + dim
|
||||
if resolved_dim < 0 or resolved_dim >= action_dim:
|
||||
raise ValueError(
|
||||
f"FastWAM action toggle dimension {dim} is out of bounds for action dim {action_dim}."
|
||||
)
|
||||
value = processed_action[..., resolved_dim]
|
||||
value = value * 2.0 - 1.0
|
||||
processed_action[..., resolved_dim] = torch.sign(-value)
|
||||
return processed_action
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"toggle_dimensions": self.toggle_dimensions}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def make_fastwam_pre_post_processors(
|
||||
config: FastWAMConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""Create LeRobot pre- and post-processing pipelines for FastWAM.
|
||||
|
||||
Args:
|
||||
config (FastWAMConfig): Policy configuration controlling device and
|
||||
normalization feature metadata.
|
||||
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): Optional
|
||||
LeRobot dataset statistics used by normalization processors.
|
||||
|
||||
Returns:
|
||||
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: Input and
|
||||
output processor pipelines discoverable by LeRobot.
|
||||
"""
|
||||
|
||||
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
|
||||
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
|
||||
for key, feature in config.input_features.items():
|
||||
if feature.type != FeatureType.VISUAL:
|
||||
continue
|
||||
channels = int(feature.shape[0])
|
||||
normalization_stats[key] = {
|
||||
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
|
||||
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# resize visual inputs to match model expected input size, if necessary
|
||||
visual_shapes = [
|
||||
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
resize_steps = []
|
||||
if visual_shapes:
|
||||
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
|
||||
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
|
||||
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
*resize_steps,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=normalization_stats,
|
||||
device=config.device,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=normalization_stats,
|
||||
),
|
||||
]
|
||||
if config.toggle_action_dimensions:
|
||||
output_steps.append(
|
||||
FastWAMActionToggleProcessorStep(toggle_dimensions=config.toggle_action_dimensions)
|
||||
)
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
# Wan2.2 Upstream Subset
|
||||
|
||||
This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM.
|
||||
|
||||
- Upstream repository: https://github.com/Wan-Video/Wan2.2
|
||||
- Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc`
|
||||
- License: Apache-2.0, matching the license in `LICENSE.txt` from the upstream repository
|
||||
|
||||
Copied files:
|
||||
|
||||
- `wan/modules/attention.py`
|
||||
- `wan/modules/model.py`
|
||||
- `wan/modules/__init__.py`
|
||||
- `wan/utils/fm_solvers.py`
|
||||
- `wan/utils/__init__.py`
|
||||
|
||||
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
|
||||
UMT5 text encoder, and tokenizer are no longer vendored — they come from
|
||||
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
|
||||
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
|
||||
|
||||
Current FastWAM adapters that directly reuse this vendored subset:
|
||||
|
||||
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
|
||||
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .attention import flash_attention
|
||||
from .model import WanModel
|
||||
|
||||
__all__ = [
|
||||
"WanModel",
|
||||
"flash_attention",
|
||||
]
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
FLASH_ATTN_2_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_2_AVAILABLE = False
|
||||
|
||||
import warnings
|
||||
|
||||
__all__ = [
|
||||
"flash_attention",
|
||||
"attention",
|
||||
]
|
||||
|
||||
|
||||
def flash_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q_lens=None,
|
||||
k_lens=None,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
q_scale=None,
|
||||
causal=False,
|
||||
window_size=(-1, -1),
|
||||
deterministic=False,
|
||||
dtype=torch.bfloat16,
|
||||
version=None,
|
||||
):
|
||||
"""
|
||||
q: [B, Lq, Nq, C1].
|
||||
k: [B, Lk, Nk, C1].
|
||||
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
||||
q_lens: [B].
|
||||
k_lens: [B].
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
window_size: (left right). If not (-1, -1), apply sliding window local attention.
|
||||
deterministic: bool. If True, slightly slower and uses more memory.
|
||||
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
||||
"""
|
||||
half_dtypes = (torch.float16, torch.bfloat16)
|
||||
assert dtype in half_dtypes
|
||||
assert q.device.type == "cuda" and q.size(-1) <= 256
|
||||
|
||||
# params
|
||||
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
||||
|
||||
def half(x):
|
||||
return x if x.dtype in half_dtypes else x.to(dtype)
|
||||
|
||||
# preprocess query
|
||||
if q_lens is None:
|
||||
q = half(q.flatten(0, 1))
|
||||
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
|
||||
else:
|
||||
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens, strict=False)]))
|
||||
|
||||
# preprocess key, value
|
||||
if k_lens is None:
|
||||
k = half(k.flatten(0, 1))
|
||||
v = half(v.flatten(0, 1))
|
||||
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
|
||||
else:
|
||||
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens, strict=False)]))
|
||||
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens, strict=False)]))
|
||||
|
||||
q = q.to(v.dtype)
|
||||
k = k.to(v.dtype)
|
||||
|
||||
if q_scale is not None:
|
||||
q = q * q_scale
|
||||
|
||||
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
||||
warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.", stacklevel=2)
|
||||
|
||||
# apply attention
|
||||
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
||||
# Note: dropout_p, window_size are not supported in FA3 now.
|
||||
x = flash_attn_interface.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
|
||||
.cumsum(0, dtype=torch.int32)
|
||||
.to(q.device, non_blocking=True),
|
||||
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
|
||||
.cumsum(0, dtype=torch.int32)
|
||||
.to(q.device, non_blocking=True),
|
||||
seqused_q=None,
|
||||
seqused_k=None,
|
||||
max_seqlen_q=lq,
|
||||
max_seqlen_k=lk,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
deterministic=deterministic,
|
||||
)[0].unflatten(0, (b, lq))
|
||||
else:
|
||||
assert FLASH_ATTN_2_AVAILABLE
|
||||
x = flash_attn.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
|
||||
.cumsum(0, dtype=torch.int32)
|
||||
.to(q.device, non_blocking=True),
|
||||
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
|
||||
.cumsum(0, dtype=torch.int32)
|
||||
.to(q.device, non_blocking=True),
|
||||
max_seqlen_q=lq,
|
||||
max_seqlen_k=lk,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
deterministic=deterministic,
|
||||
).unflatten(0, (b, lq))
|
||||
|
||||
# output
|
||||
return x.type(out_dtype)
|
||||
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q_lens=None,
|
||||
k_lens=None,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
q_scale=None,
|
||||
causal=False,
|
||||
window_size=(-1, -1),
|
||||
deterministic=False,
|
||||
dtype=torch.bfloat16,
|
||||
fa_version=None,
|
||||
):
|
||||
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
||||
return flash_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
q_lens=q_lens,
|
||||
k_lens=k_lens,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=softmax_scale,
|
||||
q_scale=q_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
deterministic=deterministic,
|
||||
dtype=dtype,
|
||||
version=fa_version,
|
||||
)
|
||||
else:
|
||||
if q_lens is not None or k_lens is not None:
|
||||
warnings.warn(
|
||||
"Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.",
|
||||
stacklevel=2,
|
||||
)
|
||||
attn_mask = None
|
||||
|
||||
q = q.transpose(1, 2).to(dtype)
|
||||
k = k.transpose(1, 2).to(dtype)
|
||||
v = v.transpose(1, 2).to(dtype)
|
||||
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
|
||||
)
|
||||
|
||||
out = out.transpose(1, 2).contiguous()
|
||||
return out
|
||||
@@ -0,0 +1,519 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from .attention import flash_attention
|
||||
|
||||
__all__ = ["WanModel"]
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
|
||||
# preprocess
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
position = position.type(torch.float64)
|
||||
|
||||
# calculation
|
||||
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
||||
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
|
||||
def rope_params(max_seq_len, dim, theta=10000):
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(
|
||||
torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
|
||||
)
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
|
||||
def rope_apply(x, grid_sizes, freqs):
|
||||
n, c = x.size(2), x.size(3) // 2
|
||||
|
||||
# split freqs
|
||||
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
||||
|
||||
# loop over samples
|
||||
output = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
||||
seq_len = f * h * w
|
||||
|
||||
# precompute multipliers
|
||||
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
|
||||
freqs_i = torch.cat(
|
||||
[
|
||||
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
|
||||
],
|
||||
dim=-1,
|
||||
).reshape(seq_len, 1, -1)
|
||||
|
||||
# apply rotary embedding
|
||||
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
||||
x_i = torch.cat([x_i, x[i, seq_len:]])
|
||||
|
||||
# append to collection
|
||||
output.append(x_i)
|
||||
return torch.stack(output).float()
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
"""
|
||||
return self._norm(x.float()).type_as(x) * self.weight
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.LayerNorm):
|
||||
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
||||
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, seq_lens, grid_sizes, freqs):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
seq_lens(Tensor): Shape [B]
|
||||
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
def qkv_fn(x):
|
||||
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
||||
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = qkv_fn(x)
|
||||
|
||||
x = flash_attention(
|
||||
q=rope_apply(q, grid_sizes, freqs),
|
||||
k=rope_apply(k, grid_sizes, freqs),
|
||||
v=v,
|
||||
k_lens=seq_lens,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
|
||||
# output
|
||||
x = x.flatten(2)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanCrossAttention(WanSelfAttention):
|
||||
def forward(self, x, context, context_lens):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
context(Tensor): Shape [B, L2, C]
|
||||
context_lens(Tensor): Shape [B]
|
||||
"""
|
||||
b, n, d = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
||||
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
||||
v = self.v(context).view(b, -1, n, d)
|
||||
|
||||
# compute attention
|
||||
x = flash_attention(q, k, v, k_lens=context_lens)
|
||||
|
||||
# output
|
||||
x = x.flatten(2)
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)
|
||||
)
|
||||
|
||||
# modulation
|
||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
e,
|
||||
seq_lens,
|
||||
grid_sizes,
|
||||
freqs,
|
||||
context,
|
||||
context_lens,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, C]
|
||||
e(Tensor): Shape [B, L1, 6, C]
|
||||
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
||||
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||
"""
|
||||
assert e.dtype == torch.float32
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
|
||||
assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
y = self.self_attn(
|
||||
self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), seq_lens, grid_sizes, freqs
|
||||
)
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
x = x + y * e[2].squeeze(2)
|
||||
|
||||
# cross-attention & ffn function
|
||||
def cross_attn_ffn(x, context, context_lens, e):
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
||||
y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
x = x + y * e[5].squeeze(2)
|
||||
return x
|
||||
|
||||
x = cross_attn_ffn(x, context, context_lens, e)
|
||||
return x
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.eps = eps
|
||||
|
||||
# layers
|
||||
out_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, out_dim)
|
||||
|
||||
# modulation
|
||||
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
|
||||
def forward(self, x, e):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L1, C]
|
||||
e(Tensor): Shape [B, L1, C]
|
||||
"""
|
||||
assert e.dtype == torch.float32
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
|
||||
x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))
|
||||
return x
|
||||
|
||||
|
||||
class WanModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
||||
"""
|
||||
|
||||
ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
|
||||
_no_split_modules = ["WanAttentionBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
model_type="t2v",
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
):
|
||||
r"""
|
||||
Initialize the diffusion model backbone.
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
Fixed length for text embeddings
|
||||
in_dim (`int`, *optional*, defaults to 16):
|
||||
Input video channels (C_in)
|
||||
dim (`int`, *optional*, defaults to 2048):
|
||||
Hidden dimension of the transformer
|
||||
ffn_dim (`int`, *optional*, defaults to 8192):
|
||||
Intermediate dimension in feed-forward network
|
||||
freq_dim (`int`, *optional*, defaults to 256):
|
||||
Dimension for sinusoidal time embeddings
|
||||
text_dim (`int`, *optional*, defaults to 4096):
|
||||
Input dimension for text embeddings
|
||||
out_dim (`int`, *optional*, defaults to 16):
|
||||
Output video channels (C_out)
|
||||
num_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads
|
||||
num_layers (`int`, *optional*, defaults to 32):
|
||||
Number of transformer blocks
|
||||
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
||||
Window size for local attention (-1 indicates global attention)
|
||||
qk_norm (`bool`, *optional*, defaults to True):
|
||||
Enable query/key normalization
|
||||
cross_attn_norm (`bool`, *optional*, defaults to False):
|
||||
Enable cross-attention normalization
|
||||
eps (`float`, *optional*, defaults to 1e-6):
|
||||
Epsilon value for normalization layers
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.text_len = text_len
|
||||
self.in_dim = in_dim
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.freq_dim = freq_dim
|
||||
self.text_dim = text_dim
|
||||
self.out_dim = out_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.window_size = window_size
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.text_embedding = nn.Sequential(
|
||||
nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
|
||||
)
|
||||
|
||||
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# head
|
||||
self.head = Head(dim, out_dim, patch_size, eps)
|
||||
|
||||
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
||||
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
||||
d = dim // num_heads
|
||||
self.freqs = torch.cat(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# initialize weights
|
||||
self.init_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
y=None,
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (List[Tensor]):
|
||||
List of input video tensors, each with shape [C_in, F, H, W]
|
||||
t (Tensor):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Tensor]):
|
||||
List of text embeddings each with shape [L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
y (List[Tensor], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
if self.model_type == "i2v":
|
||||
assert y is not None
|
||||
# params
|
||||
device = self.patch_embedding.weight.device
|
||||
if self.freqs.device != device:
|
||||
self.freqs = self.freqs.to(device)
|
||||
|
||||
if y is not None:
|
||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y, strict=False)]
|
||||
|
||||
# embeddings
|
||||
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
||||
grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
||||
assert seq_lens.max() <= seq_len
|
||||
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
|
||||
|
||||
# time embeddings
|
||||
if t.dim() == 1:
|
||||
t = t.expand(t.size(0), seq_len)
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
bt = t.size(0)
|
||||
t = t.flatten()
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_len)).float()
|
||||
)
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
||||
|
||||
# context
|
||||
context_lens = None
|
||||
context = self.text_embedding(
|
||||
torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
|
||||
)
|
||||
|
||||
# arguments
|
||||
kwargs = {
|
||||
"e": e0,
|
||||
"seq_lens": seq_lens,
|
||||
"grid_sizes": grid_sizes,
|
||||
"freqs": self.freqs,
|
||||
"context": context,
|
||||
"context_lens": context_lens,
|
||||
}
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return [u.float() for u in x]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
r"""
|
||||
Reconstruct video tensors from patch embeddings.
|
||||
|
||||
Args:
|
||||
x (List[Tensor]):
|
||||
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
||||
grid_sizes (Tensor):
|
||||
Original spatial-temporal grid dimensions before patching,
|
||||
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
|
||||
c = self.out_dim
|
||||
out = []
|
||||
for u, v in zip(x, grid_sizes.tolist(), strict=False):
|
||||
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
|
||||
u = torch.einsum("fhwpqrc->cfphqwr", u)
|
||||
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size, strict=False)])
|
||||
out.append(u)
|
||||
return out
|
||||
|
||||
def init_weights(self):
|
||||
r"""
|
||||
Initialize model parameters using Xavier initialization.
|
||||
"""
|
||||
|
||||
# basic init
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
# init embeddings
|
||||
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
||||
for m in self.text_embedding.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, std=0.02)
|
||||
for m in self.time_embedding.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, std=0.02)
|
||||
|
||||
# init output layer
|
||||
nn.init.zeros_(self.head.head.weight)
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .fm_solvers import get_sampling_sigmas
|
||||
|
||||
__all__ = [
|
||||
"get_sampling_sigmas",
|
||||
]
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_sampling_sigmas(sampling_steps, shift):
|
||||
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
|
||||
sigma = shift * sigma / (1 + (shift - 1) * sigma)
|
||||
return sigma
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
|
||||
class WanVideoVAE38(torch.nn.Module):
|
||||
"""FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B).
|
||||
|
||||
16x spatial / 4x temporal compression, 48 latent channels. diffusers'
|
||||
`AutoencoderKLWan` returns *raw* latents (it does not apply `latents_mean`/
|
||||
`latents_std`), so `encode`/`decode` here apply the same standardization the
|
||||
Wan reference uses — `(latents - mean) / std` — done in fp32 for stability.
|
||||
`encode` uses the deterministic posterior mode, matching the original VAE
|
||||
which returned the latent mean `mu`.
|
||||
"""
|
||||
|
||||
upsampling_factor = 16
|
||||
temporal_downsample_factor = 4
|
||||
z_dim = 48
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: str | torch.device = "cuda",
|
||||
*,
|
||||
pretrained: AutoencoderKLWan,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# The Wan2.2 VAE is a fixed pretrained model — it is never trained from scratch,
|
||||
# so a real `AutoencoderKLWan` (with weights) must always be supplied (loaded from
|
||||
# the diffusers repo by `load_pretrained_wan_vae`). No random/offline build path.
|
||||
self.vae = pretrained.to(device=device, dtype=dtype)
|
||||
|
||||
# Read the standardization stats from the VAE's own config (diffusers populates
|
||||
# these from vae/config.json) — single source of truth, no local copy. diffusers'
|
||||
# encode/decode return *raw* latents, so we apply (latent - mean) / std ourselves.
|
||||
# Non-persistent: kept out of state_dict.
|
||||
self.register_buffer(
|
||||
"latents_mean",
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.z_dim, 1, 1, 1),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"latents_std",
|
||||
torch.tensor(self.vae.config.latents_std).view(1, self.z_dim, 1, 1, 1),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def _device_dtype(self) -> tuple[torch.device, torch.dtype]:
|
||||
param = next(self.vae.parameters())
|
||||
return param.device, param.dtype
|
||||
|
||||
def encode(
|
||||
self,
|
||||
videos: list[torch.Tensor] | torch.Tensor,
|
||||
device: str | torch.device | None = None,
|
||||
tiled: bool = False,
|
||||
tile_size: tuple[int, int] = (34, 34),
|
||||
tile_stride: tuple[int, int] = (18, 16),
|
||||
) -> torch.Tensor:
|
||||
del device, tile_size, tile_stride
|
||||
if tiled:
|
||||
raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.")
|
||||
if isinstance(videos, (list, tuple)):
|
||||
videos = torch.stack(list(videos))
|
||||
dev, dtype = self._device_dtype()
|
||||
mu = self.vae.encode(videos.to(device=dev, dtype=dtype)).latent_dist.mode().float()
|
||||
mean = self.latents_mean.float().to(mu.device)
|
||||
std = self.latents_std.float().to(mu.device)
|
||||
return (mu - mean) / std
|
||||
|
||||
def decode(
|
||||
self,
|
||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||
device: str | torch.device | None = None,
|
||||
tiled: bool = False,
|
||||
tile_size: tuple[int, int] = (34, 34),
|
||||
tile_stride: tuple[int, int] = (18, 16),
|
||||
) -> torch.Tensor:
|
||||
del device, tile_size, tile_stride
|
||||
if tiled:
|
||||
raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.")
|
||||
if isinstance(hidden_states, (list, tuple)):
|
||||
hidden_states = torch.stack(list(hidden_states))
|
||||
dev, dtype = self._device_dtype()
|
||||
z = hidden_states.float()
|
||||
z = z * self.latents_std.float().to(z.device) + self.latents_mean.float().to(z.device)
|
||||
out = self.vae.decode(z.to(device=dev, dtype=dtype)).sample
|
||||
return out.float().clamp_(-1.0, 1.0)
|
||||
|
||||
|
||||
__all__ = ["WanVideoVAE38"]
|
||||
@@ -0,0 +1,172 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
|
||||
# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text
|
||||
# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one.
|
||||
WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
|
||||
WAN_T5_TOKENIZER = "google/umt5-xxl"
|
||||
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
|
||||
|
||||
|
||||
class WanTextEncoder(torch.nn.Module):
|
||||
"""FastWAM text-encoder contract over `transformers.UMT5EncoderModel`.
|
||||
|
||||
Exposes `.dim` (hidden size) and `forward(ids, mask) -> [B, L, dim]`, matching
|
||||
the call in `FastWAM.encode_prompt`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str | torch.device = "cuda",
|
||||
*,
|
||||
pretrained: torch.nn.Module,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# UMT5-XXL is a fixed pretrained encoder — never trained from scratch, so a real
|
||||
# `UMT5EncoderModel` (with weights) must always be supplied (loaded from the
|
||||
# diffusers repo by `load_pretrained_wan_text_encoder`). No random/offline build.
|
||||
self.model = pretrained.to(device=device, dtype=dtype)
|
||||
self.dim = int(self.model.config.d_model)
|
||||
|
||||
def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
return self.model(input_ids=ids, attention_mask=mask.long()).last_hidden_state
|
||||
|
||||
|
||||
class WanTokenizer:
|
||||
"""UMT5 tokenizer wrapper returning `(input_ids, attention_mask)` like the
|
||||
FastWAM call site expects."""
|
||||
|
||||
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
self.seq_len = int(seq_len)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sequence: str | Sequence[str],
|
||||
return_mask: bool = False,
|
||||
add_special_tokens: bool = True,
|
||||
**_: Any,
|
||||
):
|
||||
if isinstance(sequence, str):
|
||||
sequence = [sequence]
|
||||
out = self.tokenizer(
|
||||
list(sequence),
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.seq_len,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if return_mask:
|
||||
return out.input_ids, out.attention_mask
|
||||
return out.input_ids
|
||||
|
||||
|
||||
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
|
||||
return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len))
|
||||
|
||||
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
|
||||
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
|
||||
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
|
||||
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
|
||||
|
||||
|
||||
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
|
||||
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
|
||||
encoder = UMT5EncoderModel.from_pretrained(
|
||||
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
|
||||
)
|
||||
return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder)
|
||||
|
||||
|
||||
def resolve_wan_dit_paths(
|
||||
model_id_or_path: str | Path,
|
||||
*,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
) -> list[Path]:
|
||||
"""Resolve the custom MoT DiT shards from the original Wan2.2 repo or a local dir."""
|
||||
path = Path(model_id_or_path).expanduser()
|
||||
if path.is_dir():
|
||||
return sorted(path.glob(WAN_DIT_PATTERN))
|
||||
|
||||
snapshot_path = snapshot_download(
|
||||
repo_id=str(model_id_or_path),
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
allow_patterns=[WAN_DIT_PATTERN],
|
||||
)
|
||||
return sorted(Path(snapshot_path).glob(WAN_DIT_PATTERN))
|
||||
|
||||
|
||||
def load_wan_video_dit(
|
||||
paths: list[str | Path],
|
||||
*,
|
||||
dit_config: dict[str, Any],
|
||||
torch_dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> WanVideoDiT:
|
||||
model = WanVideoDiT(**dit_config)
|
||||
state_dict = _read_wan_dit_safetensors(paths)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model.to(device=device, dtype=torch_dtype)
|
||||
|
||||
|
||||
def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
for path in paths:
|
||||
state_dict.update(load_file(str(path), device="cpu"))
|
||||
return state_dict
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WAN22_DIFFUSERS_MODEL_ID",
|
||||
"WAN_DIT_PATTERN",
|
||||
"WAN_T5_TOKENIZER",
|
||||
"WanTextEncoder",
|
||||
"WanTokenizer",
|
||||
"build_wan_tokenizer",
|
||||
"load_pretrained_wan_text_encoder",
|
||||
"load_pretrained_wan_vae",
|
||||
"load_wan_video_dit",
|
||||
"resolve_wan_dit_paths",
|
||||
]
|
||||
@@ -0,0 +1,813 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
from einops import rearrange
|
||||
|
||||
from .wan.modules.model import (
|
||||
WanAttentionBlock,
|
||||
WanLayerNorm,
|
||||
WanModel,
|
||||
WanRMSNorm,
|
||||
rope_apply,
|
||||
rope_params,
|
||||
sinusoidal_embedding_1d,
|
||||
)
|
||||
from .wan.utils.fm_solvers import get_sampling_sigmas
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs, **kwargs):
|
||||
return module(*inputs, **kwargs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if use_gradient_checkpointing:
|
||||
model_output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(model),
|
||||
*args,
|
||||
**kwargs,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
model_output = model(*args, **kwargs)
|
||||
return model_output
|
||||
|
||||
|
||||
def fastwam_masked_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
num_heads: int,
|
||||
ctx_mask: torch.Tensor | None = None,
|
||||
fp32_attention: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""FastWAM masked attention wrapper for MoT masks and CPU test coverage.
|
||||
|
||||
The official Wan attention implementation is still used as the source of
|
||||
the projection/norm modules. This wrapper only replaces the final attention
|
||||
kernel because FastWAM needs explicit boolean masks for video/action MoT
|
||||
routing, while the upstream FlashAttention path accepts sequence lengths
|
||||
but not arbitrary [query, key] masks.
|
||||
"""
|
||||
|
||||
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
||||
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
||||
if fp32_attention:
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
v = v.float()
|
||||
else:
|
||||
q = q.to(dtype=v.dtype)
|
||||
k = k.to(dtype=v.dtype)
|
||||
x = functional.scaled_dot_product_attention(q, k, v, attn_mask=ctx_mask)
|
||||
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
|
||||
return get_sampling_sigmas(num_inference_steps, shift)
|
||||
|
||||
|
||||
class WanContinuousFlowMatchScheduler:
|
||||
"""Continuous-time Flow-Matching scheduler with shift-based Wan sampling."""
|
||||
|
||||
def __init__(self, num_train_timesteps: int = 1000, shift: float = 5.0, eps: float = 1e-10):
|
||||
if num_train_timesteps <= 0:
|
||||
raise ValueError(f"`num_train_timesteps` must be positive, got {num_train_timesteps}")
|
||||
if shift <= 0:
|
||||
raise ValueError(f"`shift` must be positive, got {shift}")
|
||||
self.num_train_timesteps = int(num_train_timesteps)
|
||||
self.shift = float(shift)
|
||||
self.eps = float(eps)
|
||||
self._y_min, self._weight_norm_const = self._precompute_training_weight_stats()
|
||||
|
||||
@staticmethod
|
||||
def _phi(u: torch.Tensor, shift: float) -> torch.Tensor:
|
||||
return shift * u / (1.0 + (shift - 1.0) * u)
|
||||
|
||||
def _precompute_training_weight_stats(self) -> tuple[float, float]:
|
||||
steps = self.num_train_timesteps
|
||||
u_grid = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float64)[:-1]
|
||||
t_grid = self._phi(u_grid, self.shift) * float(steps)
|
||||
y_grid = torch.exp(-2.0 * ((t_grid - (steps / 2.0)) / steps) ** 2)
|
||||
y_min = float(y_grid.min().item())
|
||||
y_shifted_grid = y_grid - y_min
|
||||
norm_const = float(y_shifted_grid.mean().item())
|
||||
return y_min, norm_const
|
||||
|
||||
def sample_training_t(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"`batch_size` must be positive, got {batch_size}")
|
||||
u = torch.rand((batch_size,), device=device, dtype=torch.float32)
|
||||
sigma = self._phi(u, self.shift)
|
||||
timestep = sigma * float(self.num_train_timesteps)
|
||||
return timestep.to(dtype=dtype)
|
||||
|
||||
def training_weight(self, timestep: torch.Tensor) -> torch.Tensor:
|
||||
t = timestep.to(dtype=torch.float32)
|
||||
steps = float(self.num_train_timesteps)
|
||||
y = torch.exp(-2.0 * ((t - (steps / 2.0)) / steps) ** 2)
|
||||
y_shifted = y - self._y_min
|
||||
weight = y_shifted / (self._weight_norm_const + self.eps)
|
||||
if weight.numel() == 1:
|
||||
return weight.reshape(())
|
||||
return weight
|
||||
|
||||
def add_noise(
|
||||
self, original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
sigma = (timestep / float(self.num_train_timesteps)).to(
|
||||
original_samples.device, dtype=original_samples.dtype
|
||||
)
|
||||
if sigma.ndim == 0:
|
||||
return (1 - sigma) * original_samples + sigma * noise
|
||||
sigma = sigma.view(-1, *([1] * (original_samples.ndim - 1)))
|
||||
return (1 - sigma) * original_samples + sigma * noise
|
||||
|
||||
@staticmethod
|
||||
def training_target(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
||||
del timestep
|
||||
return noise - sample
|
||||
|
||||
def build_inference_schedule(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
shift_override: float | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be positive, got {num_inference_steps}")
|
||||
shift = self.shift if shift_override is None else float(shift_override)
|
||||
if shift <= 0:
|
||||
raise ValueError(f"`shift` must be positive, got {shift}")
|
||||
|
||||
sigma_steps = torch.as_tensor(
|
||||
_get_wan_sampling_sigmas(num_inference_steps, shift),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
timesteps = sigma_steps * float(self.num_train_timesteps)
|
||||
sigma_next = torch.cat([sigma_steps[1:], sigma_steps.new_zeros(1)])
|
||||
deltas = sigma_next - sigma_steps
|
||||
return timesteps.to(dtype=dtype), deltas.to(dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def step(model_output: torch.Tensor, delta: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
|
||||
delta = delta.to(sample.device, dtype=sample.dtype)
|
||||
if delta.ndim == 0:
|
||||
return sample + model_output * delta
|
||||
delta = delta.view(-1, *([1] * (sample.ndim - 1)))
|
||||
return sample + model_output * delta
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
||||
return rope_params(end, dim, theta)
|
||||
|
||||
|
||||
def apply_dense_rope(x: torch.Tensor, freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
x_out = torch.view_as_complex(x.to(torch.float32).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
|
||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
def _linear_input(linear: nn.Linear, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(dtype=linear.weight.dtype)
|
||||
|
||||
|
||||
def _wan_layer_norm(norm: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(norm, WanLayerNorm) and norm.weight is not None:
|
||||
weight = norm.weight.float()
|
||||
bias = norm.bias.float() if norm.bias is not None else None
|
||||
return functional.layer_norm(x.float(), norm.normalized_shape, weight, bias, norm.eps).to(
|
||||
dtype=x.dtype
|
||||
)
|
||||
return norm(x)
|
||||
|
||||
|
||||
def create_group_causal_attn_mask(
|
||||
num_temporal_groups: int, num_query_per_group: int, num_key_per_group: int, mode: str = "causal"
|
||||
) -> torch.Tensor:
|
||||
if mode not in ["causal", "group_diagonal"]:
|
||||
raise ValueError(f"`mode` must be 'causal' or 'group_diagonal', got {mode}.")
|
||||
if num_temporal_groups <= 0:
|
||||
raise ValueError(f"`num_temporal_groups` must be positive, got {num_temporal_groups}.")
|
||||
if num_query_per_group <= 0:
|
||||
raise ValueError(f"`num_query_per_group` must be positive, got {num_query_per_group}.")
|
||||
if num_key_per_group <= 0:
|
||||
raise ValueError(f"`num_key_per_group` must be positive, got {num_key_per_group}.")
|
||||
|
||||
total_num_query_tokens = num_temporal_groups * num_query_per_group
|
||||
total_num_key_tokens = num_temporal_groups * num_key_per_group
|
||||
query_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_query_per_group).unsqueeze(1)
|
||||
key_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_key_per_group).unsqueeze(0)
|
||||
|
||||
if mode == "causal":
|
||||
attn_mask = query_time_indices >= key_time_indices
|
||||
else:
|
||||
attn_mask = query_time_indices == key_time_indices
|
||||
|
||||
if attn_mask.shape != (total_num_query_tokens, total_num_key_tokens):
|
||||
raise RuntimeError("Attention mask shape mismatch.")
|
||||
return attn_mask
|
||||
|
||||
|
||||
class FastWAMAttentionBlock(WanAttentionBlock):
|
||||
"""Wan attention block with FastWAM's arbitrary boolean mask support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
attn_head_dim: int,
|
||||
num_heads: int,
|
||||
ffn_dim: int,
|
||||
eps: float = 1e-6,
|
||||
fp32_attention: bool = True,
|
||||
):
|
||||
attention_dim = attn_head_dim * num_heads
|
||||
if hidden_dim == attention_dim:
|
||||
super().__init__(
|
||||
dim=hidden_dim,
|
||||
ffn_dim=ffn_dim,
|
||||
num_heads=num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=eps,
|
||||
)
|
||||
else:
|
||||
nn.Module.__init__(self)
|
||||
self.dim = hidden_dim
|
||||
self.ffn_dim = ffn_dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = (-1, -1)
|
||||
self.qk_norm = True
|
||||
self.cross_attn_norm = True
|
||||
self.eps = eps
|
||||
self.norm1 = WanLayerNorm(hidden_dim, eps)
|
||||
self.self_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
|
||||
self.norm3 = WanLayerNorm(hidden_dim, eps, elementwise_affine=True)
|
||||
self.cross_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
|
||||
self.norm2 = WanLayerNorm(hidden_dim, eps)
|
||||
self.ffn = nn.Sequential(
|
||||
nn.Linear(hidden_dim, ffn_dim),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(ffn_dim, hidden_dim),
|
||||
)
|
||||
self.modulation = nn.Parameter(torch.randn(1, 6, hidden_dim) / hidden_dim**0.5)
|
||||
self.attn_head_dim = attn_head_dim
|
||||
self.fp32_attention = bool(fp32_attention)
|
||||
|
||||
@staticmethod
|
||||
def split_modulation(block, t_mod: torch.Tensor):
|
||||
has_seq = len(t_mod.shape) == 4
|
||||
chunk_dim = 2 if has_seq else 1
|
||||
|
||||
base_mod = block.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (base_mod + t_mod).chunk(
|
||||
6, dim=chunk_dim
|
||||
)
|
||||
if has_seq:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
shift_msa.squeeze(2),
|
||||
scale_msa.squeeze(2),
|
||||
gate_msa.squeeze(2),
|
||||
shift_mlp.squeeze(2),
|
||||
scale_mlp.squeeze(2),
|
||||
gate_mlp.squeeze(2),
|
||||
)
|
||||
return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
def project_self_attention(
|
||||
self, x: torch.Tensor, freqs: torch.Tensor | dict[str, torch.Tensor]
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q = self.self_attn.norm_q(self.self_attn.q(x))
|
||||
k = self.self_attn.norm_k(self.self_attn.k(x))
|
||||
v = self.self_attn.v(x)
|
||||
if isinstance(freqs, dict):
|
||||
b, s = x.shape[:2]
|
||||
q = rope_apply(
|
||||
q.view(b, s, self.num_heads, self.attn_head_dim),
|
||||
freqs["grid_sizes"],
|
||||
freqs["freqs"],
|
||||
).flatten(2)
|
||||
k = rope_apply(
|
||||
k.view(b, s, self.num_heads, self.attn_head_dim),
|
||||
freqs["grid_sizes"],
|
||||
freqs["freqs"],
|
||||
).flatten(2)
|
||||
else:
|
||||
q = apply_dense_rope(q, freqs, self.num_heads)
|
||||
k = apply_dense_rope(k, freqs, self.num_heads)
|
||||
return q, k, v
|
||||
|
||||
def apply_cross_attention(
|
||||
self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
if context_mask is not None and context_mask.dim() == 3:
|
||||
context_mask = context_mask.unsqueeze(1)
|
||||
attn = self.cross_attn
|
||||
b, n, d = x.size(0), attn.num_heads, attn.head_dim
|
||||
q = attn.norm_q(attn.q(x)).view(b, -1, n * d)
|
||||
k = attn.norm_k(attn.k(context)).view(b, -1, n * d)
|
||||
v = attn.v(context).view(b, -1, n * d)
|
||||
x = fastwam_masked_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
num_heads=n,
|
||||
ctx_mask=context_mask,
|
||||
fp32_attention=self.fp32_attention,
|
||||
)
|
||||
return attn.o(_linear_input(attn.o, x))
|
||||
|
||||
def project_self_attention_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_attn.o(_linear_input(self.self_attn.o, x))
|
||||
|
||||
def apply_norm1(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return _wan_layer_norm(self.norm1, x)
|
||||
|
||||
def apply_norm2(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return _wan_layer_norm(self.norm2, x)
|
||||
|
||||
def apply_norm3(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return _wan_layer_norm(self.norm3, x)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
t_mod: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
self_attn_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.split_modulation(self, t_mod)
|
||||
residual_x = x
|
||||
attn_input = modulate(self.apply_norm1(x), shift_msa, scale_msa)
|
||||
q, k, v = self.project_self_attention(attn_input, freqs)
|
||||
y = fastwam_masked_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
num_heads=self.num_heads,
|
||||
ctx_mask=self_attn_mask,
|
||||
fp32_attention=self.fp32_attention,
|
||||
)
|
||||
x = residual_x + gate_msa * self.project_self_attention_output(y)
|
||||
x = x + self.apply_cross_attention(self.apply_norm3(x), context, context_mask=context_mask)
|
||||
mlp_input = modulate(self.apply_norm2(x), shift_mlp, scale_mlp)
|
||||
return x + gate_mlp * self.ffn(mlp_input)
|
||||
|
||||
|
||||
class _FastWAMProjectedAttention(nn.Module):
|
||||
def __init__(self, hidden_dim: int, attention_dim: int, num_heads: int, eps: float):
|
||||
super().__init__()
|
||||
self.dim = hidden_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = attention_dim // num_heads
|
||||
self.q = nn.Linear(hidden_dim, attention_dim)
|
||||
self.k = nn.Linear(hidden_dim, attention_dim)
|
||||
self.v = nn.Linear(hidden_dim, attention_dim)
|
||||
self.o = nn.Linear(attention_dim, hidden_dim)
|
||||
self.norm_q = WanRMSNorm(attention_dim, eps=eps)
|
||||
self.norm_k = WanRMSNorm(attention_dim, eps=eps)
|
||||
|
||||
|
||||
class WanVideoDiT(WanModel):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
in_dim: int,
|
||||
ffn_dim: int,
|
||||
out_dim: int,
|
||||
text_dim: int,
|
||||
freq_dim: int,
|
||||
eps: float,
|
||||
patch_size: tuple[int, int, int],
|
||||
num_heads: int,
|
||||
attn_head_dim: int,
|
||||
num_layers: int,
|
||||
has_image_input: bool = False,
|
||||
has_image_pos_emb: bool = False,
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
separated_timestep: bool = False,
|
||||
require_vae_embedding: bool = False,
|
||||
require_clip_embedding: bool = False,
|
||||
fuse_vae_embedding_in_latents: bool = True,
|
||||
action_conditioned: bool = False,
|
||||
action_dim: int = 7,
|
||||
action_group_causal_mask_mode="causal",
|
||||
video_attention_mask_mode: str = "bidirectional",
|
||||
use_gradient_checkpointing: bool = False,
|
||||
fp32_attention: bool = True,
|
||||
):
|
||||
del in_dim_control_adapter
|
||||
if has_image_input:
|
||||
raise ValueError("FastWAM currently expects Wan2.2 TI2V latents with fused image conditioning.")
|
||||
if has_image_pos_emb:
|
||||
raise ValueError("FastWAM does not support extra image positional embeddings in WanVideoDiT.")
|
||||
if has_ref_conv:
|
||||
raise ValueError("FastWAM does not support reference convolutions in WanVideoDiT.")
|
||||
if add_control_adapter:
|
||||
raise ValueError("FastWAM does not support control adapters in WanVideoDiT.")
|
||||
if require_clip_embedding:
|
||||
raise ValueError("FastWAM does not support CLIP embedding conditioning in WanVideoDiT.")
|
||||
if require_vae_embedding or not fuse_vae_embedding_in_latents:
|
||||
raise ValueError("FastWAM expects VAE conditioning to be fused in latents.")
|
||||
if attn_head_dim != hidden_dim // num_heads:
|
||||
raise ValueError(
|
||||
"`attn_head_dim` must match the upstream Wan head dimension `hidden_dim // num_heads`; "
|
||||
f"got {attn_head_dim} vs {hidden_dim // num_heads}."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model_type="ti2v",
|
||||
patch_size=patch_size,
|
||||
text_len=512,
|
||||
in_dim=in_dim,
|
||||
dim=hidden_dim,
|
||||
ffn_dim=ffn_dim,
|
||||
freq_dim=freq_dim,
|
||||
text_dim=text_dim,
|
||||
out_dim=out_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=eps,
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
FastWAMAttentionBlock(
|
||||
hidden_dim=hidden_dim,
|
||||
attn_head_dim=attn_head_dim,
|
||||
num_heads=num_heads,
|
||||
ffn_dim=ffn_dim,
|
||||
eps=eps,
|
||||
fp32_attention=fp32_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.init_weights()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.attn_head_dim = attn_head_dim
|
||||
self.separated_timestep = separated_timestep
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
self.video_attention_mask_mode = str(video_attention_mask_mode)
|
||||
self.action_conditioned = action_conditioned
|
||||
self.action_dim = action_dim
|
||||
self.fp32_attention = bool(fp32_attention)
|
||||
|
||||
if self.action_conditioned:
|
||||
self.action_embedding = torch.nn.Linear(action_dim, hidden_dim)
|
||||
self.action_group_causal_mask_mode = action_group_causal_mask_mode
|
||||
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
if self.use_gradient_checkpointing:
|
||||
logger.info(
|
||||
"Using gradient checkpointing for DiT blocks. This will save memory but use more computation."
|
||||
)
|
||||
|
||||
def patchify(self, x: torch.Tensor):
|
||||
return self.patch_embedding(x)
|
||||
|
||||
def _validate_forward_inputs(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
context_mask: torch.Tensor | None,
|
||||
action: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if x.ndim != 5:
|
||||
raise ValueError(f"`latents` must be 5D [B, C, T, H, W], got shape {tuple(x.shape)}")
|
||||
num_latent_frames = x.shape[2]
|
||||
if context.ndim != 3:
|
||||
raise ValueError(f"`context` must be 3D [B, L, D], got shape {tuple(context.shape)}")
|
||||
if timestep.ndim != 1:
|
||||
raise ValueError(f"`timestep` must be 1D [B] or [1], got shape {tuple(timestep.shape)}")
|
||||
if self.action_conditioned:
|
||||
allow_text_only_single_frame = num_latent_frames == 1 and action is None
|
||||
if not allow_text_only_single_frame:
|
||||
if action is None:
|
||||
raise ValueError("Action input is required for action-conditioned model.")
|
||||
if action.ndim != 3:
|
||||
raise ValueError(
|
||||
f"`action` must be 3D [B, action_horizon, action_dim], got shape {tuple(action.shape)}"
|
||||
)
|
||||
if action.shape[2] != self.action_dim:
|
||||
raise ValueError(
|
||||
f"`action` last dimension must be {self.action_dim}, got {action.shape[2]}"
|
||||
)
|
||||
if num_latent_frames <= 1:
|
||||
raise ValueError(
|
||||
f"video length must be > 1 for action-conditioned model, got {num_latent_frames}"
|
||||
)
|
||||
if action.shape[1] % (num_latent_frames - 1) != 0:
|
||||
raise ValueError(
|
||||
"action horizon must be divisible by (num_latent_frames - 1), "
|
||||
f"got action_horizon={action.shape[1]}"
|
||||
)
|
||||
if context_mask is None:
|
||||
context_mask = torch.ones(
|
||||
(context.shape[0], context.shape[1]), dtype=torch.bool, device=context.device
|
||||
)
|
||||
else:
|
||||
if context_mask.ndim != 2:
|
||||
raise ValueError(f"`context_mask` must be 2D [B, L], got shape {tuple(context_mask.shape)}")
|
||||
if context_mask.shape[0] != context.shape[0] or context_mask.shape[1] != context.shape[1]:
|
||||
raise ValueError(
|
||||
"`context_mask` shape must match `context` shape [B, L], "
|
||||
f"got {tuple(context_mask.shape)} vs {tuple(context.shape)}"
|
||||
)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
if batch_size != context.shape[0]:
|
||||
if not self.training and batch_size == 1:
|
||||
x = x.expand(context.shape[0], -1, -1, -1, -1)
|
||||
batch_size = context.shape[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Batch mismatch between latents and context: {batch_size} vs {context.shape[0]}."
|
||||
)
|
||||
|
||||
if timestep.shape[0] not in (1, batch_size):
|
||||
raise ValueError(
|
||||
f"`timestep` length must be 1 or batch_size({batch_size}), got {timestep.shape[0]}"
|
||||
)
|
||||
if timestep.shape[0] == 1 and batch_size > 1:
|
||||
if self.training:
|
||||
raise ValueError("During training, timestep length must match batch_size.")
|
||||
timestep = timestep.expand(batch_size)
|
||||
return x, timestep, context_mask
|
||||
|
||||
def build_video_to_video_mask(
|
||||
self,
|
||||
video_seq_len: int,
|
||||
video_tokens_per_frame: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
if video_seq_len <= 0:
|
||||
raise ValueError(f"`video_seq_len` must be positive, got {video_seq_len}")
|
||||
if video_tokens_per_frame <= 0:
|
||||
raise ValueError(f"`video_tokens_per_frame` must be positive, got {video_tokens_per_frame}")
|
||||
|
||||
if self.video_attention_mask_mode == "bidirectional":
|
||||
return torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
|
||||
|
||||
if self.video_attention_mask_mode == "per_frame_causal":
|
||||
if video_seq_len % video_tokens_per_frame != 0:
|
||||
raise ValueError(
|
||||
"`video_seq_len` must be divisible by `video_tokens_per_frame` in `per_frame_causal` mode, "
|
||||
f"got {video_seq_len} and {video_tokens_per_frame}"
|
||||
)
|
||||
num_video_frames = video_seq_len // video_tokens_per_frame
|
||||
frame_causal = torch.tril(
|
||||
torch.ones((num_video_frames, num_video_frames), dtype=torch.bool, device=device)
|
||||
)
|
||||
return frame_causal.repeat_interleave(video_tokens_per_frame, dim=0).repeat_interleave(
|
||||
video_tokens_per_frame, dim=1
|
||||
)
|
||||
|
||||
if self.video_attention_mask_mode == "first_frame_causal":
|
||||
video_mask = torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
|
||||
first_frame_tokens = min(video_tokens_per_frame, video_seq_len)
|
||||
video_mask[:first_frame_tokens, first_frame_tokens:] = False
|
||||
return video_mask
|
||||
|
||||
raise ValueError(f"Unsupported video attention mask mode: {self.video_attention_mask_mode}")
|
||||
|
||||
def pre_dit(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
action: torch.Tensor | None = None,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
x, timestep, context_mask = self._validate_forward_inputs(
|
||||
x=x,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
context_mask=context_mask,
|
||||
action=action,
|
||||
)
|
||||
model_dtype = self.patch_embedding.weight.dtype
|
||||
x = x.to(dtype=model_dtype)
|
||||
context = context.to(dtype=model_dtype)
|
||||
if action is not None:
|
||||
action = action.to(dtype=model_dtype)
|
||||
|
||||
batch_size = x.shape[0]
|
||||
patch_h = int(self.patch_size[1])
|
||||
patch_w = int(self.patch_size[2])
|
||||
if x.shape[3] % patch_h != 0 or x.shape[4] % patch_w != 0:
|
||||
raise ValueError(
|
||||
"Latent spatial shape must be divisible by DiT patch size, "
|
||||
f"got HxW=({x.shape[3]}, {x.shape[4]}), patch=({patch_h}, {patch_w})"
|
||||
)
|
||||
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
|
||||
|
||||
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
|
||||
raise NotImplementedError(
|
||||
"FastWAM currently requires separated timesteps with fused VAE latents."
|
||||
)
|
||||
|
||||
token_timesteps = torch.ones(
|
||||
(batch_size, x.shape[2], tokens_per_frame),
|
||||
dtype=model_dtype,
|
||||
device=timestep.device,
|
||||
) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1)
|
||||
token_timesteps[:, 0, :] = 0
|
||||
token_timesteps = token_timesteps.reshape(batch_size, -1)
|
||||
# Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored
|
||||
# Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift).
|
||||
# Upstream guarantees this via an fp32 autocast region, so it holds even when the
|
||||
# model runs in bf16. Mirror that here, then cast the per-block modulation back to
|
||||
# model_dtype so the bf16 attention blocks are not upcast to fp32.
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float()
|
||||
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
|
||||
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
|
||||
t_mod = t_mod.to(dtype=model_dtype)
|
||||
|
||||
x = self.patchify(x)
|
||||
f, h, w = x.shape[2:]
|
||||
|
||||
context = self.text_embedding(context)
|
||||
context_len = context.shape[1]
|
||||
if self.action_conditioned and action is not None:
|
||||
action_len = action.shape[1]
|
||||
action_emb = self.action_embedding(action)
|
||||
action_pos_embed = sinusoidal_embedding_1d(
|
||||
self.hidden_dim, torch.arange(action_len, device=action_emb.device)
|
||||
).to(dtype=action_emb.dtype)
|
||||
action_emb = action_emb + action_pos_embed.unsqueeze(0)
|
||||
context = torch.cat([context, action_emb], dim=1)
|
||||
|
||||
num_temporal_groups = f - 1
|
||||
if num_temporal_groups <= 0:
|
||||
raise ValueError(
|
||||
"Action-conditioned context mask requires at least 2 latent frames when `action` is provided."
|
||||
)
|
||||
if action_emb.shape[1] % num_temporal_groups != 0:
|
||||
raise ValueError(
|
||||
f"Action embedding length {action_emb.shape[1]} must be divisible by "
|
||||
f"number of temporal groups {num_temporal_groups}"
|
||||
)
|
||||
action_group_mask = create_group_causal_attn_mask(
|
||||
num_temporal_groups=num_temporal_groups,
|
||||
num_query_per_group=tokens_per_frame,
|
||||
num_key_per_group=action_len // num_temporal_groups,
|
||||
mode=self.action_group_causal_mask_mode,
|
||||
).to(context.device)
|
||||
|
||||
seq_len = f * h * w
|
||||
final_context_mask = torch.zeros(
|
||||
(batch_size, seq_len, context.shape[1]), dtype=torch.bool, device=context.device
|
||||
)
|
||||
final_context_mask[:, :, :context_len] = context_mask.unsqueeze(1).expand(-1, seq_len, -1)
|
||||
final_context_mask[:, tokens_per_frame:, context_len:] = action_group_mask.unsqueeze(0).expand(
|
||||
batch_size, -1, -1
|
||||
)
|
||||
context_mask = final_context_mask
|
||||
elif self.action_conditioned and action is None:
|
||||
if f != 1:
|
||||
raise ValueError(
|
||||
"Action-conditioned model requires `action` unless running single-frame text-only mode "
|
||||
"with num_latent_frames=1."
|
||||
)
|
||||
context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
|
||||
else:
|
||||
context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
|
||||
|
||||
x_tokens = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
|
||||
grid_sizes = torch.tensor([[f, h, w]] * batch_size, dtype=torch.long, device=x_tokens.device)
|
||||
freqs = {"grid_sizes": grid_sizes, "freqs": self.freqs.to(x_tokens.device)}
|
||||
|
||||
return {
|
||||
"tokens": x_tokens,
|
||||
"freqs": freqs,
|
||||
"t": t,
|
||||
"t_mod": t_mod,
|
||||
"context": context,
|
||||
"context_mask": context_mask,
|
||||
"meta": {
|
||||
"grid_sizes": grid_sizes,
|
||||
"tokens_per_frame": tokens_per_frame,
|
||||
"batch_size": batch_size,
|
||||
},
|
||||
}
|
||||
|
||||
def post_dit(self, x_tokens: torch.Tensor, pre_state: dict[str, Any]) -> torch.Tensor:
|
||||
x = self.head(x_tokens, pre_state["t"])
|
||||
return torch.stack(super().unpatchify(x, pre_state["meta"]["grid_sizes"]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
action: torch.Tensor | None = None,
|
||||
fuse_vae_embedding_in_latents: bool = False,
|
||||
):
|
||||
pre_state = self.pre_dit(
|
||||
x=x,
|
||||
timestep=timestep,
|
||||
context=context,
|
||||
context_mask=context_mask,
|
||||
action=action,
|
||||
fuse_vae_embedding_in_latents=fuse_vae_embedding_in_latents,
|
||||
)
|
||||
x_tokens = pre_state["tokens"]
|
||||
context_emb = pre_state["context"]
|
||||
t_mod = pre_state["t_mod"]
|
||||
freqs = pre_state["freqs"]
|
||||
context_attn_mask = pre_state["context_mask"]
|
||||
self_attn_mask = (
|
||||
self.build_video_to_video_mask(
|
||||
video_seq_len=x_tokens.shape[1],
|
||||
video_tokens_per_frame=int(pre_state["meta"]["tokens_per_frame"]),
|
||||
device=x_tokens.device,
|
||||
)
|
||||
if self.video_attention_mask_mode != "bidirectional"
|
||||
else None
|
||||
)
|
||||
|
||||
for block in self.blocks:
|
||||
if self.use_gradient_checkpointing:
|
||||
x_tokens = gradient_checkpoint_forward(
|
||||
block,
|
||||
self.use_gradient_checkpointing,
|
||||
x_tokens,
|
||||
context_emb,
|
||||
t_mod,
|
||||
freqs,
|
||||
context_mask=context_attn_mask,
|
||||
self_attn_mask=self_attn_mask,
|
||||
)
|
||||
else:
|
||||
x_tokens = block(
|
||||
x_tokens,
|
||||
context_emb,
|
||||
t_mod,
|
||||
freqs,
|
||||
context_mask=context_attn_mask,
|
||||
self_attn_mask=self_attn_mask,
|
||||
)
|
||||
|
||||
return self.post_dit(x_tokens, pre_state)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FastWAMAttentionBlock",
|
||||
"WanContinuousFlowMatchScheduler",
|
||||
"WanVideoDiT",
|
||||
"apply_dense_rope",
|
||||
"create_group_causal_attn_mask",
|
||||
"fastwam_masked_attention",
|
||||
"gradient_checkpoint_forward",
|
||||
"modulate",
|
||||
"precompute_freqs_cis",
|
||||
"sinusoidal_embedding_1d",
|
||||
]
|
||||
@@ -232,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -387,21 +384,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames") and not cfg.dataset.streaming:
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
@@ -426,16 +416,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
# Prepare everything with accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming IterableDataset is already rank-disjoint via split_dataset_by_node, so we must
|
||||
# NOT hand the dataloader to accelerate: its IterableDatasetShard would keep only every
|
||||
# world_size-th batch of each rank's already-disjoint stream (silently training on 1/N of the
|
||||
# data while decoding all of it). Batches are moved to the device manually in the loop below.
|
||||
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
|
||||
else:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -475,9 +458,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming dataloader is not accelerate-prepared (see above), so move to device here.
|
||||
batch = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
|
||||
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
|
||||
@@ -1,150 +0,0 @@
|
||||
"""Acceptance tests for manifest byte-index sidecars.
|
||||
|
||||
Run on a compute node (not login-node):
|
||||
|
||||
srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\
|
||||
bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\
|
||||
env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("torchcodec")
|
||||
|
||||
REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
MAX_EPISODES = 64
|
||||
|
||||
COMPUTE_NODE = pytest.mark.skipif(
|
||||
"login" in socket.gethostname(),
|
||||
reason="run on compute node via srun (see module docstring), not login-node",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def byte_index_dir(tmp_path_factory):
|
||||
from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
out = tmp_path_factory.mktemp("byte_index")
|
||||
meta = LeRobotDatasetMetadata(REPO, revision=REV)
|
||||
files, episodes, _ = build_byte_index_tables(
|
||||
meta, BUCKET, workers=4, max_episodes=MAX_EPISODES, include_keyframes=False
|
||||
)
|
||||
write_byte_index(out, files, episodes, None, merge_existing=False)
|
||||
return out, meta
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@COMPUTE_NODE
|
||||
def test_index_load_fast_and_small(byte_index_dir):
|
||||
from lerobot.datasets.byte_index import EpisodeByteIndex
|
||||
|
||||
out, meta = byte_index_dir
|
||||
index = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
|
||||
assert index.load_time_s < 1.0
|
||||
assert index.resident_bytes < 1_000_000_000
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@COMPUTE_NODE
|
||||
def test_tight_fetch_under_25mb(byte_index_dir):
|
||||
from lerobot.datasets.byte_index import EpisodeByteIndex
|
||||
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
|
||||
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
|
||||
|
||||
_, meta = byte_index_dir
|
||||
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
|
||||
cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
|
||||
for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
|
||||
cache.submit_prefetch(ep)
|
||||
cache.ensure_ready(ep)
|
||||
stats = cache.stats.stats_dict()
|
||||
assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@COMPUTE_NODE
|
||||
def test_in_memory_build_matches_parquet(byte_index_dir):
|
||||
from lerobot.datasets.byte_index import EpisodeByteIndex
|
||||
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
|
||||
|
||||
out, meta = byte_index_dir
|
||||
disk = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
|
||||
mem = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
|
||||
for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
|
||||
for cam in meta.video_keys:
|
||||
a = disk.lookup(ep, cam)
|
||||
b = mem.lookup(ep, cam)
|
||||
assert a.mdat_offset == b.mdat_offset
|
||||
assert a.mdat_length == b.mdat_length
|
||||
assert abs(a.first_pts - b.first_pts) < 1e-6
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@COMPUTE_NODE
|
||||
def test_custom_frame_mappings_available(byte_index_dir):
|
||||
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
|
||||
|
||||
_, meta = byte_index_dir
|
||||
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
|
||||
cam = meta.video_keys[0]
|
||||
ep = MAX_EPISODES // 2
|
||||
payload = index.custom_frame_mappings(ep, cam)
|
||||
assert payload is not None
|
||||
data = json.loads(payload)
|
||||
assert len(data["frames"]) > 10
|
||||
assert any(f["key_frame"] for f in data["frames"])
|
||||
assert all("pts" in f and "duration" in f for f in data["frames"])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@COMPUTE_NODE
|
||||
def test_metadata_skip_decoder_init(byte_index_dir):
|
||||
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
|
||||
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
|
||||
|
||||
_, meta = byte_index_dir
|
||||
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
|
||||
cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET)
|
||||
cam = meta.video_keys[0]
|
||||
ep = 0
|
||||
cache.submit_prefetch(ep)
|
||||
cache.ensure_ready(ep)
|
||||
dec = cache.get_decoder(ep, cam)
|
||||
assert dec.metadata.num_frames is not None
|
||||
assert dec.metadata.num_frames > 0
|
||||
begin = float(dec.metadata.begin_stream_seconds)
|
||||
end = float(dec.metadata.end_stream_seconds)
|
||||
ts = begin + 0.5 * (end - begin)
|
||||
frame = dec.get_frames_played_at([ts]).data
|
||||
assert frame.ndim == 4
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@COMPUTE_NODE
|
||||
def test_sparse_decode_produces_frames(byte_index_dir):
|
||||
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
|
||||
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
|
||||
|
||||
_, meta = byte_index_dir
|
||||
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
|
||||
cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
|
||||
cam = meta.video_keys[0]
|
||||
ep = 0
|
||||
cache.submit_prefetch(ep)
|
||||
cache.ensure_ready(ep)
|
||||
dec = cache.get_decoder(ep, cam)
|
||||
begin = float(dec.metadata.begin_stream_seconds)
|
||||
end = float(dec.metadata.end_stream_seconds)
|
||||
ts = begin + 0.5 * (end - begin)
|
||||
frame = dec.get_frames_played_at([ts]).data
|
||||
assert frame.ndim == 4
|
||||
assert frame.numel() > 0
|
||||
assert float(frame.float().std()) > 1.0
|
||||
@@ -114,30 +114,6 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == order_before
|
||||
|
||||
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -24,6 +25,52 @@ from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
|
||||
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
|
||||
rng = np.random.default_rng(streaming_ds.seed)
|
||||
buffer_size = streaming_ds.buffer_size
|
||||
num_shards = streaming_ds.num_shards
|
||||
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
|
||||
|
||||
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
|
||||
|
||||
frames_buffer = []
|
||||
expected_indices = []
|
||||
|
||||
while shard_iterators: # While there are still available shards
|
||||
available_shard_keys = list(shard_iterators.keys())
|
||||
if not available_shard_keys:
|
||||
break
|
||||
|
||||
# Call _infinite_generator_over_elements with current available shards (key difference!)
|
||||
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
|
||||
|
||||
try:
|
||||
frame_index = next(shard_iterators[shard_key])
|
||||
|
||||
if len(frames_buffer) == buffer_size:
|
||||
i = next(buffer_indices_generator)
|
||||
expected_indices.append(frames_buffer[i])
|
||||
frames_buffer[i] = frame_index
|
||||
else:
|
||||
frames_buffer.append(frame_index)
|
||||
|
||||
except StopIteration:
|
||||
del shard_iterators[shard_key] # Remove exhausted shard
|
||||
|
||||
rng.shuffle(frames_buffer)
|
||||
expected_indices.extend(frames_buffer)
|
||||
|
||||
return expected_indices
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
@@ -73,9 +120,10 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
@@ -90,17 +138,25 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
|
||||
for epoch_indices in epochs:
|
||||
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
if shuffle:
|
||||
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
|
||||
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
|
||||
else:
|
||||
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -108,11 +164,15 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
@@ -127,21 +187,31 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
def make_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
episode_pool_size=3,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
first = [int(frame["index"]) for frame in make_ds()]
|
||||
again = [int(frame["index"]) for frame in make_ds()]
|
||||
|
||||
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
assert first == again, "same seed must reproduce the same order"
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -218,11 +288,6 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats/ints where map-style yields
|
||||
# 0-dim tensors (long-standing accepted difference). Compare by value.
|
||||
check = float(left) == float(right)
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
|
||||
|
||||
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
|
||||
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
|
||||
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
|
||||
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
WORKER = """
|
||||
import json, sys
|
||||
from accelerate import PartialState
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
state = PartialState()
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
|
||||
)
|
||||
indices = [int(frame["index"]) for frame in ds]
|
||||
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
|
||||
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
|
||||
json.dump(payload, f)
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
|
||||
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
|
||||
total_frames = 160
|
||||
repo_id = f"{DUMMY_REPO_ID}-acc"
|
||||
root = tmp_path / "ds"
|
||||
lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=8,
|
||||
total_frames=total_frames,
|
||||
use_videos=False,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
)
|
||||
|
||||
worker = tmp_path / "worker.py"
|
||||
worker.write_text(WORKER)
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num_processes=2",
|
||||
"--num_machines=1",
|
||||
"--mixed_precision=no",
|
||||
"--dynamo_backend=no",
|
||||
"--cpu",
|
||||
str(worker),
|
||||
str(root),
|
||||
repo_id,
|
||||
str(out_dir),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
assert result.returncode == 0, (
|
||||
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
|
||||
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
|
||||
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
|
||||
|
||||
rank_sets = [set(p["indices"]) for p in payloads]
|
||||
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
|
||||
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
@@ -1,430 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
|
||||
prefetching, deterministic fast-forward resume, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
|
||||
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
|
||||
|
||||
monkeypatch.delenv("RANK", raising=False)
|
||||
monkeypatch.delenv("WORLD_SIZE", raising=False)
|
||||
# No accelerate state, no env -> single process.
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
|
||||
|
||||
monkeypatch.setenv("RANK", "3")
|
||||
monkeypatch.setenv("WORLD_SIZE", "4")
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
|
||||
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
|
||||
episode_frames = 300
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
|
||||
)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
|
||||
if idx + horizon_frames < episode_frames:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
|
||||
repo_id = f"{DUMMY_REPO_ID}-seeds"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
|
||||
|
||||
def order(seed):
|
||||
return _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=seed,
|
||||
episode_pool_size=4,
|
||||
max_num_shards=2,
|
||||
)
|
||||
)
|
||||
|
||||
assert order(0) == order(0), "same seed must reproduce the same order"
|
||||
assert order(0) != order(1), "different seeds should give different orders"
|
||||
|
||||
|
||||
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
|
||||
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-epochs"
|
||||
total_frames = 120
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
epoch_0 = _stream_indices(ds)
|
||||
epoch_1 = _stream_indices(ds)
|
||||
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
|
||||
assert epoch_0 != epoch_1, "epoch did not reshuffle"
|
||||
|
||||
|
||||
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-mix"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
|
||||
)
|
||||
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
|
||||
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
|
||||
|
||||
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
|
||||
"""Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=7,
|
||||
episode_pool_size=2,
|
||||
frame_shuffle_buffer_size=8,
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
consumed = [int(next(it)["index"]) for _ in range(30)]
|
||||
state = ds.state_dict()
|
||||
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
rest = [int(frame["index"]) for frame in resumed_ds]
|
||||
|
||||
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
|
||||
# in-flight buffer contents are skipped on resume (documented datasets behavior):
|
||||
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
|
||||
covered = len(set(consumed) | set(rest))
|
||||
max_in_flight = 2 * 30 + 8
|
||||
assert covered >= total_frames - max_in_flight
|
||||
assert covered + len(consumed) >= total_frames - max_in_flight
|
||||
|
||||
|
||||
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
|
||||
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
|
||||
import datasets as hf_datasets
|
||||
|
||||
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
|
||||
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
|
||||
assert state is not None
|
||||
|
||||
|
||||
# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
|
||||
|
||||
|
||||
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
|
||||
"""With one row group per episode (the writer's invariant), reshard() turns each episode into its
|
||||
own shard, so num_shards == total_episodes even when many episodes share a single data file."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-reshard"
|
||||
total_episodes = 3
|
||||
# Default (large) data-file size packs all (unequal-length) episodes into one file, so the only way
|
||||
# num_shards can reach total_episodes is per-row-group resharding.
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=90,
|
||||
use_videos=False,
|
||||
)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
|
||||
|
||||
file_to_eps = ds._episode_files()
|
||||
assert len(file_to_eps) == 1, "test expects all episodes packed into a single data file"
|
||||
for (chunk_idx, file_idx), eps in file_to_eps.items():
|
||||
rel = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
assert pq.ParquetFile(str(ds.root / rel)).num_row_groups == len(eps)
|
||||
|
||||
assert ds.num_shards == total_episodes
|
||||
|
||||
|
||||
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""max_buffer_input_shards (== concurrently-live random episodes) drives the per-batch episode mix:
|
||||
a single batch should already span most of the live episodes."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-frac"
|
||||
total_episodes = 8
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds",
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=240,
|
||||
use_videos=False,
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
episode_pool_size=total_episodes,
|
||||
max_buffer_input_shards=total_episodes,
|
||||
)
|
||||
assert ds.max_buffer_input_shards == total_episodes
|
||||
|
||||
batch = 32
|
||||
head = {int(frame["episode_index"]) for _, frame in zip(range(batch), ds, strict=False)}
|
||||
assert len(head) >= min(total_episodes, batch) - 2, f"batch did not mix random episodes: {head}"
|
||||
|
||||
|
||||
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
|
||||
"""A data file that collapses several episodes into a single row group (bulk df.to_parquet /
|
||||
push_to_hub) must be rejected with an actionable error: reshard() cannot address its episodes."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-collapsed"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
# Rewrite every data file as a single row group (simulating the aggregate/push_to_hub collapse).
|
||||
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
|
||||
pq.write_table(pq.read_table(parquet_path), parquet_path)
|
||||
|
||||
with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
|
||||
StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
|
||||
|
||||
|
||||
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
|
||||
"""validate_row_groups=False skips the row-group check (collapsed datasets still load, degraded)."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
|
||||
pq.write_table(pq.read_table(parquet_path), parquet_path)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, validate_row_groups=False
|
||||
)
|
||||
assert sorted(int(frame["index"]) for frame in ds) == list(range(90))
|
||||
|
||||
|
||||
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
|
||||
"""When num_shards (== episodes after reshard) is not divisible by world_size, every rank would
|
||||
stream the whole dataset; the guard must raise instead of silently degrading."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-divis"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
|
||||
)
|
||||
with pytest.raises(ValueError, match="not divisible by world_size"):
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, rank=0, world_size=2
|
||||
)
|
||||
|
||||
# Bypassing the guard downgrades it to a warning (no raise).
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=3,
|
||||
rank=0,
|
||||
world_size=2,
|
||||
validate_row_groups=False,
|
||||
)
|
||||
assert ds.num_shards == 3
|
||||
Vendored
+3
-22
@@ -17,7 +17,6 @@ from pathlib import Path
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -36,24 +35,6 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
|
||||
"""Write ``hf_dataset`` to ``path`` with one Parquet row group per episode.
|
||||
|
||||
Mirrors the LeRobot recording writer (one ``write_table`` per episode) so each episode stays an
|
||||
independently addressable shard after ``datasets.IterableDataset.reshard()``, which
|
||||
``StreamingLeRobotDataset`` relies on. ``Dataset.to_parquet`` would collapse the file into a
|
||||
single row group instead.
|
||||
"""
|
||||
table = hf_dataset.with_format("arrow")[:]
|
||||
episode_index = np.asarray(hf_dataset["episode_index"])
|
||||
boundaries = np.where(np.diff(episode_index) != 0)[0] + 1
|
||||
starts = [0, *boundaries.tolist()]
|
||||
ends = [*boundaries.tolist(), len(episode_index)]
|
||||
with pq.ParquetWriter(str(path), table.schema) as writer:
|
||||
for start, end in zip(starts, ends, strict=True):
|
||||
writer.write_table(table.slice(start, end - start))
|
||||
|
||||
|
||||
def write_hf_dataset(
|
||||
hf_dataset: Dataset,
|
||||
local_dir: Path,
|
||||
@@ -86,7 +67,7 @@ def write_hf_dataset(
|
||||
# If the dataset is small enough, write it to a single file.
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_to_parquet_one_row_group_per_episode(hf_dataset, path)
|
||||
hf_dataset.to_parquet(path)
|
||||
return
|
||||
|
||||
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
|
||||
@@ -133,8 +114,8 @@ def write_hf_dataset(
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the shard to a Parquet file (one row group per episode).
|
||||
_to_parquet_one_row_group_per_episode(dataset_shard, path)
|
||||
# Write the shard to a Parquet file.
|
||||
dataset_shard.to_parquet(path)
|
||||
|
||||
# Update chunk and file indices for the next iteration.
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
|
||||
@@ -0,0 +1,386 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
|
||||
from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
class FakeFastWAMCore(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dit = nn.Linear(2, 2)
|
||||
|
||||
def training_loss(self, sample):
|
||||
assert sample["video"].ndim == 5
|
||||
assert sample["context"].ndim == 3
|
||||
return sample[ACTION].sum() * 0.0 + torch.tensor(1.0), {"loss_action": 1.0}
|
||||
|
||||
def infer_action(self, **kwargs):
|
||||
return {"action": torch.ones(1, kwargs["action_horizon"], 3)}
|
||||
|
||||
|
||||
def test_fastwam_is_registered_and_publicly_exported():
|
||||
cfg = make_policy_config(
|
||||
"fastwam",
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
|
||||
assert isinstance(cfg, FastWAMConfig)
|
||||
assert cfg.type == "fastwam"
|
||||
assert get_policy_class("fastwam") is FastWAMPolicy
|
||||
|
||||
|
||||
def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path):
|
||||
cfg = FastWAMConfig()
|
||||
cfg.save_pretrained(tmp_path)
|
||||
saved = json.loads((tmp_path / "config.json").read_text())
|
||||
|
||||
assert saved["pretrained_path"] is None
|
||||
assert cfg.image_features["observation.images.image"].type == FeatureType.VISUAL
|
||||
assert cfg.action_feature.shape == (7,)
|
||||
assert cfg.robot_state_feature.shape == (8,)
|
||||
with pytest.raises(ValueError, match="image feature"):
|
||||
FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))})
|
||||
with pytest.raises(ValueError, match="tokenizer_model_id"):
|
||||
FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer")
|
||||
|
||||
|
||||
def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_path):
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
image_size=(2, 2),
|
||||
device="cpu",
|
||||
toggle_action_dimensions=[-1],
|
||||
input_features={
|
||||
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 2, 2)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
base_model_id=None,
|
||||
)
|
||||
dataset_stats = {
|
||||
"observation.images.image": {
|
||||
"mean": torch.full((3, 1, 1), 0.2),
|
||||
"std": torch.full((3, 1, 1), 0.1),
|
||||
},
|
||||
OBS_STATE: {
|
||||
"mean": torch.tensor([1.0, 3.0]),
|
||||
"std": torch.tensor([2.0, 4.0]),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(3),
|
||||
"std": torch.ones(3),
|
||||
},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)
|
||||
processed = preprocessor(
|
||||
{
|
||||
"observation.images.image": torch.tensor(
|
||||
[
|
||||
[[0.0, 0.5], [1.0, 0.5]],
|
||||
[[0.0, 0.5], [1.0, 0.5]],
|
||||
[[0.0, 0.5], [1.0, 0.5]],
|
||||
]
|
||||
),
|
||||
OBS_STATE: torch.tensor([3.0, 7.0]),
|
||||
}
|
||||
)
|
||||
preprocessor.save_pretrained(tmp_path, config_filename="policy_preprocessor.json")
|
||||
postprocessor.save_pretrained(tmp_path, config_filename="policy_postprocessor.json")
|
||||
_, loaded_postprocessor = make_pre_post_processors(cfg, pretrained_path=str(tmp_path))
|
||||
|
||||
expected_image = torch.tensor(
|
||||
[[[[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]]]]
|
||||
)
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
assert torch.allclose(processed["observation.images.image"], expected_image)
|
||||
assert torch.allclose(processed[OBS_STATE], torch.tensor([[1.0, 1.0]]))
|
||||
assert torch.equal(dataset_stats["observation.images.image"]["mean"], torch.full((3, 1, 1), 0.2))
|
||||
assert any(isinstance(step, FastWAMActionToggleProcessorStep) for step in loaded_postprocessor.steps)
|
||||
assert torch.equal(
|
||||
loaded_postprocessor(torch.tensor([[0.25, 0.5, 1.0]])), torch.tensor([[0.25, 0.5, -1.0]])
|
||||
)
|
||||
|
||||
|
||||
def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
captured = []
|
||||
|
||||
class CapturingCore(FakeFastWAMCore):
|
||||
def infer_action(self, **kwargs):
|
||||
captured.append(
|
||||
{
|
||||
"image_shape": tuple(kwargs["input_image"].shape),
|
||||
"proprio_shape": tuple(kwargs["proprio"].shape),
|
||||
"prompt": kwargs["prompt"],
|
||||
}
|
||||
)
|
||||
return {"action": torch.full((1, kwargs["action_horizon"], 3), float(len(captured)))}
|
||||
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CapturingCore())
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
image_size=(16, 16),
|
||||
input_features={
|
||||
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
base_model_id=None,
|
||||
)
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
loss, metrics = policy.forward(
|
||||
{
|
||||
"observation.images.image": torch.zeros(1, 3, 16, 16),
|
||||
OBS_STATE: torch.zeros(1, 2),
|
||||
ACTION: torch.zeros(1, 4, 3),
|
||||
"context": torch.zeros(1, 5, 4096),
|
||||
"context_mask": torch.ones(1, 5, dtype=torch.bool),
|
||||
}
|
||||
)
|
||||
action = policy.predict_action_chunk(
|
||||
{
|
||||
"observation.images.image": torch.stack(
|
||||
[
|
||||
torch.zeros(3, 16, 16),
|
||||
torch.ones(3, 16, 16),
|
||||
]
|
||||
),
|
||||
OBS_STATE: torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
|
||||
"task": ["task 0", "task 1"],
|
||||
}
|
||||
)
|
||||
|
||||
assert loss.item() == 1.0
|
||||
assert metrics["loss_action"] == 1.0
|
||||
assert action.shape == (2, 4, 3)
|
||||
assert action[:, 0, 0].tolist() == [1.0, 2.0]
|
||||
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
|
||||
assert [item["proprio_shape"] for item in captured] == [(1, 2), (1, 2)]
|
||||
assert [item["prompt"] for item in captured] == [
|
||||
cfg.prompt_template.format(task="task 0"),
|
||||
cfg.prompt_template.format(task="task 1"),
|
||||
]
|
||||
|
||||
|
||||
class CoreWithFrozenComponents(FakeFastWAMCore):
|
||||
"""Fake core mirroring the real one: frozen VAE / text encoder held as
|
||||
*unregistered* attributes (via `object.__setattr__`) so they are excluded from
|
||||
`state_dict()` and the saved checkpoint, but still moved by the `_apply` override."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
object.__setattr__(self, "vae", nn.Linear(2, 2))
|
||||
object.__setattr__(self, "text_encoder", nn.Linear(2, 2))
|
||||
self.vae.requires_grad_(False)
|
||||
self.text_encoder.requires_grad_(False)
|
||||
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
super()._apply(fn, *args, **kwargs)
|
||||
self.vae._apply(fn)
|
||||
self.text_encoder._apply(fn)
|
||||
return self
|
||||
|
||||
|
||||
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
|
||||
def build_core(self, config):
|
||||
core = CoreWithFrozenComponents()
|
||||
with torch.no_grad():
|
||||
core.dit.weight.fill_(0.5)
|
||||
return core
|
||||
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core)
|
||||
|
||||
reference = FastWAMPolicy(cfg)
|
||||
with torch.no_grad():
|
||||
reference.model.dit.weight.fill_(1.25) # a distinctive, trained-looking weight
|
||||
reference.save_pretrained(tmp_path)
|
||||
|
||||
# Building from Wan2.2 must never happen on a checkpoint load.
|
||||
def fail_if_wan_pretrained_is_loaded(*args, **kwargs):
|
||||
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
|
||||
fail_if_wan_pretrained_is_loaded,
|
||||
)
|
||||
|
||||
policy = FastWAMPolicy.from_pretrained(tmp_path)
|
||||
|
||||
assert isinstance(policy.model, CoreWithFrozenComponents)
|
||||
# The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights.
|
||||
assert torch.allclose(policy.model.dit.weight, torch.full_like(policy.model.dit.weight, 1.25))
|
||||
|
||||
|
||||
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
save_dir = tmp_path / "saved"
|
||||
policy.save_pretrained(save_dir)
|
||||
|
||||
assert (save_dir / "model.safetensors").is_file()
|
||||
# No Wan sidecar files either: the frozen backbone comes from the diffusers repo.
|
||||
assert not (save_dir / "Wan2.2_VAE.safetensors").exists()
|
||||
assert not (save_dir / "google").exists()
|
||||
|
||||
with safe_open(save_dir / "model.safetensors", framework="pt") as f:
|
||||
keys = set(f.keys())
|
||||
# Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text
|
||||
# encoder are excluded (loaded from the diffusers/transformers repos at init).
|
||||
assert any(key.startswith("model.dit.") for key in keys)
|
||||
assert not any(key.startswith("model.vae.") for key in keys)
|
||||
assert not any(key.startswith("model.text_encoder.") for key in keys)
|
||||
|
||||
|
||||
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
# Unregistered: excluded from state_dict and from the optimizer's parameter set.
|
||||
sd = policy.state_dict()
|
||||
assert not any(k.startswith("model.vae.") or k.startswith("model.text_encoder.") for k in sd)
|
||||
param_names = [n for n, _ in policy.named_parameters()]
|
||||
assert not any("vae" in n or "text_encoder" in n for n in param_names)
|
||||
|
||||
# ...but the `_apply` override still carries them through `.to()` (dtype stands in
|
||||
# for device on a CPU box), so they never strand off the rest of the model.
|
||||
policy.to(torch.float64)
|
||||
assert policy.model.dit.weight.dtype == torch.float64 # registered
|
||||
assert policy.model.vae.weight.dtype == torch.float64 # unregistered, moved via _apply
|
||||
assert policy.model.text_encoder.weight.dtype == torch.float64
|
||||
|
||||
|
||||
def test_pretrained_config_round_trips_fastwam_features(tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=7, proprio_dim=8, image_size=(224, 448), base_model_id=None)
|
||||
cfg.save_pretrained(tmp_path)
|
||||
|
||||
loaded = PreTrainedConfig.from_pretrained(tmp_path)
|
||||
|
||||
assert loaded.type == "fastwam"
|
||||
assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL
|
||||
assert loaded.action_feature.shape == (7,)
|
||||
assert loaded.robot_state_feature.shape == (8,)
|
||||
|
||||
|
||||
def test_vae_adapter_empty_build_encode_decode_shapes():
|
||||
"""Offline glue check of the diffusers-backed VAE adapter (random weights).
|
||||
|
||||
Validates the encode/decode contract — 48 latent channels, 16x spatial / 4x
|
||||
temporal compression, list-or-batch input, scaling round-trip — without any
|
||||
weight download. (Numerical fidelity vs the original Wan VAE is a separate,
|
||||
GPU + real-weights verification step.)
|
||||
"""
|
||||
pytest.importorskip("diffusers")
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38
|
||||
|
||||
# Production always loads a real pretrained VAE from the diffusers repo; here we
|
||||
# build the same architecture with random weights and dummy standardization stats
|
||||
# to exercise the adapter's shape/scaling contract offline (fidelity is checked
|
||||
# separately, with real weights, on GPU).
|
||||
arch = {
|
||||
"base_dim": 160,
|
||||
"decoder_base_dim": 256,
|
||||
"z_dim": 48,
|
||||
"dim_mult": [1, 2, 4, 4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_scales": [],
|
||||
"temporal_downsample": [False, True, True],
|
||||
"dropout": 0.0,
|
||||
"is_residual": True,
|
||||
"in_channels": 12,
|
||||
"out_channels": 12,
|
||||
"patch_size": 2,
|
||||
"scale_factor_spatial": 16,
|
||||
"scale_factor_temporal": 4,
|
||||
"clip_output": False,
|
||||
"latents_mean": [0.0] * 48,
|
||||
"latents_std": [1.0] * 48,
|
||||
}
|
||||
raw = AutoencoderKLWan.from_config(arch)
|
||||
vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=raw)
|
||||
assert vae.z_dim == 48
|
||||
assert vae.upsampling_factor == 16
|
||||
assert vae.temporal_downsample_factor == 4
|
||||
|
||||
video = torch.rand(1, 3, 5, 32, 32) * 2 - 1 # [B,C,T,H,W] in [-1,1]
|
||||
latents = vae.encode(video)
|
||||
assert latents.shape == (1, 48, 2, 2, 2) # T'=(5-1)//4+1, H'=W'=32//16
|
||||
|
||||
decoded = vae.decode(latents)
|
||||
assert decoded.shape[0] == 1 and decoded.shape[1] == 3 and decoded.shape[-2:] == (32, 32)
|
||||
assert decoded.min() >= -1.0 and decoded.max() <= 1.0
|
||||
|
||||
# list input is accepted and equals the batched path
|
||||
assert torch.equal(vae.encode([video[0]]), latents)
|
||||
@@ -1084,8 +1084,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "5.0.1.dev0"
|
||||
source = { git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3#2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
{ name = "filelock" },
|
||||
@@ -1102,6 +1102,10 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
@@ -1760,7 +1764,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
version = "0.1.4"
|
||||
version = "0.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dm-control" },
|
||||
@@ -1768,14 +1772,14 @@ dependencies = [
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "mujoco" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-hil"
|
||||
version = "0.1.14"
|
||||
version = "0.1.13"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "gymnasium" },
|
||||
@@ -1785,9 +1789,9 @@ dependencies = [
|
||||
{ name = "pygame" },
|
||||
{ name = "pynput" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1877,7 +1881,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
|
||||
|
||||
[[package]]
|
||||
name = "hf-libero"
|
||||
version = "0.1.4"
|
||||
version = "0.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bddl", marker = "sys_platform == 'linux'" },
|
||||
@@ -1898,10 +1902,7 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'linux'" },
|
||||
{ name = "wandb", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
@@ -2829,6 +2830,10 @@ eo1 = [
|
||||
evaluation = [
|
||||
{ name = "av" },
|
||||
]
|
||||
fastwam = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
feetech = [
|
||||
{ name = "deepdiff" },
|
||||
{ name = "feetech-servo-sdk" },
|
||||
@@ -3074,7 +3079,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
@@ -3089,12 +3094,12 @@ requires-dist = [
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
|
||||
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
|
||||
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
@@ -3122,11 +3127,13 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["deepdiff-dep"], marker = "extra == 'hardware'" },
|
||||
{ name = "lerobot", extras = ["dev"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'fastwam'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["fastwam"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
|
||||
@@ -3195,6 +3202,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'fastwam'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
|
||||
@@ -3275,7 +3283,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user