mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-11 05:39:49 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fa3eb9fce3 | |||
| 500c91ba92 |
@@ -1,91 +0,0 @@
|
||||
# Streaming dataloading benchmark
|
||||
|
||||
Measures **dataloading only** (no model) for `StreamingLeRobotDataset`: parquet read + video decode +
|
||||
delta windowing + shuffle. A dummy consumer pulls batches and moves them to the device, so the numbers
|
||||
isolate the data pipeline. Use it to compare sources (Hub vs. storage bucket vs. prewarmed bucket),
|
||||
frame modes, and node counts, and to catch p95/p99 video-decode regressions.
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id pepijn223/robocasa_pretrain_human300_v4 \
|
||||
--mode sarm --batch_size 64 --num_workers 12 --num_batches 200 \
|
||||
--source hub --out_dir benchmarks/streaming/results
|
||||
```
|
||||
|
||||
Multinode (per-node throughput) goes through Accelerate under SLURM:
|
||||
|
||||
```bash
|
||||
sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
```
|
||||
|
||||
## Matrix
|
||||
|
||||
| Axis | Values |
|
||||
| ---------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| Source | `hub` (verify now), `bucket`, `warmed_bucket` (bucket + prewarming; with user's help later) |
|
||||
| Baseline | current `main` `StreamingLeRobotDataset` on Hub streaming |
|
||||
| Nodes | 1 and 2 (per-node throughput should be independent) |
|
||||
| Frame mode | `single` (1 frame, all cameras; target ≥ 120 frames/s/node) · `sarm` (8 steps spaced 1s; target ≥ 320 frames/s/node) |
|
||||
|
||||
`--source` is a label only; the actual source is whatever `--repo_id` / `--root` / `--data_files_root`
|
||||
point at.
|
||||
|
||||
### GPU (NVDEC) decoding
|
||||
|
||||
By default video is decoded on the **CPU** in each DataLoader worker, so throughput is CPU-decode-bound and
|
||||
scales with `--num_workers` (capped by the dataset's `num_shards`). Pass `--video_decode_device cuda` to
|
||||
offload H.264/H.265 decode to the GPU's dedicated **NVDEC** engine, which runs independently of the SMs used
|
||||
for training (see <https://developer.nvidia.com/video-codec-sdk>). This requires a CUDA-enabled torchcodec
|
||||
build, and because CUDA cannot initialize in forked workers the benchmark switches to the `spawn` start
|
||||
method automatically when `--num_workers > 0`.
|
||||
|
||||
```bash
|
||||
# GPU/NVDEC decode, 6 workers, bucket source
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id pepijn223/robocasa_pretrain_human300_v4 \
|
||||
--data_files_root hf://buckets/pepijn223/robocasa-stream \
|
||||
--mode sarm --batch_size 64 --num_workers 6 --num_batches 200 \
|
||||
--video_decode_device cuda --source bucket
|
||||
```
|
||||
|
||||
Caveats with `cuda` + many workers: each worker creates its own CUDA context (VRAM overhead) and NVDEC has a
|
||||
limited number of concurrent decode sessions per GPU; if you hit session/IPC limits, reduce `--num_workers`
|
||||
or compare against `--num_workers 0` (single-process NVDEC, which often saturates the decode engine on its
|
||||
own). Result files include the decode device in their name (`..._w6_cuda.json`).
|
||||
|
||||
> **Codec ⇄ NVDEC compatibility (important).** NVDEC can only decode codecs its hardware supports. LeRobot's
|
||||
> **default video codec is AV1** (`VideoEncoderConfig.vcodec = "libsvtav1"`), so most v3 datasets are
|
||||
> AV1-encoded — and the **A100 and H100 compute GPUs have no AV1 NVDEC decoder**
|
||||
> (per NVIDIA's [decode support matrix](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new));
|
||||
> only Ada (L4/L40/RTX40) and a few Ampere cards (A10/A40/A16) do. On A100/H100, AV1 must be decoded on
|
||||
> **CPU**, or the dataset re-encoded to H.265/H.264 (which those GPUs' NVDEC do support). Run
|
||||
> `diagnose_decode.py --video_decode_device cuda` to check your exact node before relying on `cuda` decode.
|
||||
> A `cuda` torchcodec build also needs an FFmpeg with NVDEC; see
|
||||
> <https://github.com/meta-pytorch/torchcodec#installing-cuda-enabled-torchcodec>.
|
||||
|
||||
Reference data root: bucket sources resolve through `--data_files_root hf://buckets/<owner>/<name>` (metadata
|
||||
still loads from `--repo_id`). The local `single`/`sarm` CPU baselines on this dataset were ~176 / ~212
|
||||
frames/s/node at `--num_workers 3` (3 cameras, fps 20).
|
||||
|
||||
## Metrics emitted (JSON + CSV)
|
||||
|
||||
`frames_per_s_node`, `samples_per_s`, `first_batch_latency_s`, `p50/p95/p99_sample_latency_ms`,
|
||||
`wallclock_s`, and `video_decoder_cache` (`hits`, `misses`, `evictions`, `hit_rate`, `size`). A low
|
||||
cache `hit_rate` with high `p99` is the decoder-thrash signature — raise `--video_decoder_cache_size`
|
||||
or `--buffer_size`, or reduce `num_workers`.
|
||||
|
||||
## Bucket sources & prewarming (manual)
|
||||
|
||||
Prewarming is a **server-side** Hugging Face storage-bucket feature — there is no client script. To
|
||||
benchmark the `warmed_bucket` source:
|
||||
|
||||
1. Attach a storage bucket to the dataset and enable it (see
|
||||
<https://huggingface.co/docs/hub/storage-buckets>). Buckets resolve through `fsspec`, the same as
|
||||
`hf://`, so no code change is needed — point `--repo_id`/`--revision` (or `--root`) at the bucket.
|
||||
2. Enable **prewarming** in the bucket settings and wait for warm-up to complete.
|
||||
3. Run the benchmark with `--source warmed_bucket`. Compare against the cold `--source bucket` and the
|
||||
`--source hub` baseline.
|
||||
|
||||
Manual only — not run in CI.
|
||||
@@ -1,209 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataloading-only benchmark for StreamingLeRobotDataset.
|
||||
|
||||
A dummy consumer pulls batches and moves them to the device; no model runs, so the numbers isolate the
|
||||
data pipeline (parquet read + video decode + delta windowing + shuffle). Reports per-node throughput and
|
||||
sample-latency percentiles, plus video-decoder-cache reuse stats, and emits JSON + CSV.
|
||||
|
||||
Frame modes (matching the streaming design targets):
|
||||
- ``single``: one frame, all cameras (target >= 120 frames/s/node).
|
||||
- ``sarm``: an 8-step window spaced 1s (delta over 8s) (target >= 320 frames/s/node).
|
||||
|
||||
Example (stream from the Hub, single node):
|
||||
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id pepijn223/robocasa_pretrain_human300_v4 --mode sarm \
|
||||
--batch_size 64 --num_workers 12 --num_batches 200 --out_dir benchmarks/streaming/results
|
||||
|
||||
Distributed / multinode runs go through Accelerate; see ``slurm/benchmark_streaming_robocasa.sh``. Set
|
||||
``--source`` purely for labeling the output (``hub`` / ``bucket`` / ``warmed_bucket``); the actual source
|
||||
is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarming.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--repo_id", type=str, required=True)
|
||||
parser.add_argument("--root", type=str, default=None, help="Local/prewarmed root (else stream from Hub).")
|
||||
parser.add_argument(
|
||||
"--data_files_root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="fsspec root for bulk data/videos, e.g. hf://buckets/<owner>/<name>. Metadata still loads "
|
||||
"from --repo_id on the Hub. Use for bucket / warmed_bucket sources.",
|
||||
)
|
||||
parser.add_argument("--mode", choices=["single", "sarm"], default="single")
|
||||
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--buffer_size", type=int, default=2000)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--video_decode_device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Decode device passed to torchcodec. 'cuda' offloads decode to the GPU's NVDEC engine "
|
||||
"(needs a CUDA-enabled torchcodec build). With num_workers>0 this forces the 'spawn' start method.",
|
||||
)
|
||||
parser.add_argument("--num_batches", type=int, default=200)
|
||||
parser.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state stats.")
|
||||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--out_dir", type=str, default="benchmarks/streaming/results")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> StreamingLeRobotDataset:
|
||||
# sarm: an 8-step window spaced 1s => an 8s delta window (the SARM stress case).
|
||||
delta_timestamps = {ACTION: [float(t) for t in range(8)]} if args.mode == "sarm" else None
|
||||
return StreamingLeRobotDataset(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
data_files_root=args.data_files_root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
video_decode_device=args.video_decode_device,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
device = torch.device(args.device)
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
|
||||
dataset = build_dataset(args, meta)
|
||||
|
||||
gpu_decode = args.video_decode_device.startswith("cuda")
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
# GPU-decoded frames are already on the GPU, so CPU pinning is irrelevant (and pinning CUDA
|
||||
# tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
|
||||
pin_memory=device.type == "cuda" and not gpu_decode,
|
||||
drop_last=True,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None,
|
||||
# CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
|
||||
multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
|
||||
)
|
||||
|
||||
sample_latencies_ms: list[float] = []
|
||||
frames = 0
|
||||
first_batch_latency_s = None
|
||||
steady_start = None # wall-clock start of the post-warmup measurement window
|
||||
|
||||
t_start = time.perf_counter()
|
||||
t_prev = t_start
|
||||
for i, batch in enumerate(loader):
|
||||
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
|
||||
if i == args.warmup_batches:
|
||||
# Start the steady window here; the slow first batch and the prefetch queue it filled are
|
||||
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
frames += args.batch_size
|
||||
t_prev = now
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
|
||||
now = time.perf_counter()
|
||||
elapsed = now - t_start
|
||||
# Wall-clock throughput over the steady window. NOT sum(inter-batch gaps): under async prefetch those
|
||||
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
|
||||
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
|
||||
cache_stats = dataset.video_decoder_cache_stats()
|
||||
|
||||
# A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
|
||||
# error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
|
||||
# never reported green with NaN/zero numbers.
|
||||
if frames == 0:
|
||||
raise SystemExit(
|
||||
f"FAILED: measured 0 frames over {args.num_batches} requested batches "
|
||||
f"(cache misses={cache_stats.get('misses', 0)}, hits={cache_stats.get('hits', 0)}). "
|
||||
"The data pipeline yielded no usable batches — inspect worker logs for decode errors. "
|
||||
"Try --num_workers 0 to surface the underlying exception directly."
|
||||
)
|
||||
|
||||
results = {
|
||||
"repo_id": args.repo_id,
|
||||
"source": args.source,
|
||||
"mode": args.mode,
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"buffer_size": args.buffer_size,
|
||||
"num_cameras": len(meta.video_keys),
|
||||
"fps": meta.fps,
|
||||
"device": str(device),
|
||||
"video_decode_device": args.video_decode_device,
|
||||
"frames_measured": frames,
|
||||
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
|
||||
"frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"samples_per_s": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
|
||||
if sample_latencies_ms
|
||||
else None,
|
||||
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
|
||||
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
|
||||
"wallclock_s": round(elapsed, 2),
|
||||
"video_decoder_cache": cache_stats,
|
||||
}
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}"
|
||||
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
|
||||
flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
|
||||
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=list(flat))
|
||||
writer.writeheader()
|
||||
writer.writerow(flat)
|
||||
|
||||
print("Command config:", vars(args))
|
||||
print(json.dumps(results, indent=2))
|
||||
print(f"Wrote {out_dir / tag}.json and .csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,112 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Isolate the streaming video-decode path — no SLURM, no DataLoader, no benchmark loop.
|
||||
|
||||
Reproduces exactly what StreamingLeRobotDataset does for one video (resolve path -> fsspec.open ->
|
||||
torchcodec VideoDecoder -> get one frame) and prints the environment + the first bytes of the handle, so
|
||||
a decode failure ("No valid stream found in input file") can be pinpointed: bad/placeholder bytes vs a
|
||||
torchcodec/ffmpeg build issue vs a device issue.
|
||||
|
||||
python benchmarks/streaming/diagnose_decode.py --repo_id pepijn223/robocasa_pretrain_human300_v4
|
||||
python benchmarks/streaming/diagnose_decode.py --repo_id … --data_files_root hf://buckets/<o>/<n>
|
||||
python benchmarks/streaming/diagnose_decode.py --repo_id … --video_decode_device cuda
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib.metadata as im
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets import LeRobotDatasetMetadata
|
||||
|
||||
|
||||
def _version(pkg: str) -> str:
|
||||
try:
|
||||
return im.version(pkg)
|
||||
except Exception:
|
||||
return "MISSING"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--repo_id", required=True)
|
||||
p.add_argument("--data_files_root", default=None, help="e.g. hf://buckets/<owner>/<name>")
|
||||
p.add_argument("--revision", default=None)
|
||||
p.add_argument("--video_decode_device", default="cpu")
|
||||
p.add_argument("--episode", type=int, default=0)
|
||||
args = p.parse_args()
|
||||
|
||||
print("== environment ==")
|
||||
for pkg in ("torchcodec", "av", "huggingface_hub", "hf_xet", "datasets", "fsspec"):
|
||||
print(f" {pkg}: {_version(pkg)}")
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
video_key = meta.video_keys[0]
|
||||
rel_path = meta.get_video_file_path(args.episode, video_key)
|
||||
root = args.data_files_root.rstrip("/") if args.data_files_root else meta.url_root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
print("\n== target ==")
|
||||
print(f" video_key: {video_key}")
|
||||
print(f" video_path: {video_path}")
|
||||
|
||||
print("\n== fsspec handle ==")
|
||||
try:
|
||||
fh = fsspec.open(video_path).__enter__()
|
||||
head = fh.read(32)
|
||||
print(f" first 32 bytes (hex): {head.hex()}")
|
||||
# A valid MP4/MOV has an 'ftyp' box near the start; anything else (HTML/JSON/empty) means the
|
||||
# handle resolved to a placeholder or error page, not the video bytes.
|
||||
looks_mp4 = b"ftyp" in head
|
||||
print(f" looks like MP4 (contains 'ftyp'): {looks_mp4}")
|
||||
if not looks_mp4:
|
||||
print(f" !! first bytes as text: {head[:32]!r}")
|
||||
fh.seek(0)
|
||||
except Exception as e:
|
||||
print(f" !! fsspec.open/read FAILED: {type(e).__name__}: {e}")
|
||||
return
|
||||
|
||||
print("\n== torchcodec VideoDecoder ==")
|
||||
try:
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
decoder = VideoDecoder(fh, seek_mode="approximate", device=args.video_decode_device)
|
||||
md = decoder.metadata
|
||||
print(f" OK: {md.num_frames} frames, {md.average_fps} fps, codec={getattr(md, 'codec', '?')}")
|
||||
frame = decoder.get_frames_at(indices=[0])
|
||||
print(f" decoded frame 0: shape={tuple(frame.data.shape)}, device={frame.data.device}")
|
||||
print("\nDECODE OK — the streaming pipeline can read this video on this machine.")
|
||||
except Exception as e:
|
||||
print(f" !! VideoDecoder FAILED: {type(e).__name__}: {e}")
|
||||
print(
|
||||
"\nDECODE FAILED. If the bytes above look like MP4 (ftyp=True), this is a torchcodec/ffmpeg "
|
||||
"build issue, NOT bad bytes. Common cause for LeRobot v3 datasets: the videos are AV1-encoded "
|
||||
"(see the 'codec' line on a working machine). Then:\n"
|
||||
" - CPU decode needs an ffmpeg built with an AV1 decoder (libdav1d/libaom); a build without it "
|
||||
"reports 'No valid stream found'.\n"
|
||||
" - GPU/NVDEC decode of AV1 is only on AV1-capable NVDEC GPUs: Ada (L4/L40/RTX40) and some "
|
||||
"Ampere (A10/A40/A16). The COMPUTE GPUs A100 and H100 have NO AV1 NVDEC decoder (per NVIDIA's "
|
||||
"support matrix), so no torchcodec build enables cuda decode of AV1 on them.\n"
|
||||
" - 'Unsupported device: cuda (variant: ffmpeg)' instead means torchcodec was built without "
|
||||
"the CUDA backend; install a CUDA-enabled wheel (see README) — but on A100/H100 that still "
|
||||
"won't decode AV1.\n"
|
||||
"Fix: decode on CPU, run NVDEC on an Ada GPU, or re-encode the dataset to H.265/H.264 (which "
|
||||
"A100/H100 NVDEC do support).\n"
|
||||
"If ftyp=False instead, the handle resolved to a placeholder/error page (auth, revision, or Xet "
|
||||
"resolution) rather than the video bytes."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,79 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Collapse a directory of benchmark JSON results into one comparison table (and a combined CSV).
|
||||
|
||||
python benchmarks/streaming/summarize_results.py benchmarks/streaming/results
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
COLUMNS = [
|
||||
("source", "source"),
|
||||
("mode", "mode"),
|
||||
("video_decode_device", "decode"),
|
||||
("num_workers", "workers"),
|
||||
("batch_size", "bs"),
|
||||
("frames_per_s_node", "frames/s/node"),
|
||||
("first_batch_latency_s", "first_batch_s"),
|
||||
("p50_sample_latency_ms", "p50_ms"),
|
||||
("p95_sample_latency_ms", "p95_ms"),
|
||||
("p99_sample_latency_ms", "p99_ms"),
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "benchmarks/streaming/results")
|
||||
files = sorted(results_dir.rglob("*.json"))
|
||||
if not files:
|
||||
print(f"No JSON results under {results_dir}")
|
||||
return
|
||||
|
||||
rows = []
|
||||
for f in files:
|
||||
d = json.loads(f.read_text())
|
||||
d["hit_rate"] = d.get("video_decoder_cache", {}).get("hit_rate")
|
||||
rows.append(d)
|
||||
|
||||
rows.sort(key=lambda r: (r.get("source", ""), r.get("mode", ""), r.get("video_decode_device", "")))
|
||||
|
||||
headers = [label for _, label in COLUMNS] + ["cache_hit_rate"]
|
||||
widths = {h: len(h) for h in headers}
|
||||
table = []
|
||||
for r in rows:
|
||||
row = {label: r.get(key, "") for key, label in COLUMNS}
|
||||
row["cache_hit_rate"] = r.get("hit_rate", "")
|
||||
table.append(row)
|
||||
for h in headers:
|
||||
widths[h] = max(widths[h], len(str(row[h])))
|
||||
|
||||
line = " ".join(h.ljust(widths[h]) for h in headers)
|
||||
print(line)
|
||||
print(" ".join("-" * widths[h] for h in headers))
|
||||
for row in table:
|
||||
print(" ".join(str(row[h]).ljust(widths[h]) for h in headers))
|
||||
|
||||
combined = results_dir / "summary.csv"
|
||||
with open(combined, "w", newline="") as fh:
|
||||
writer = csv.DictWriter(fh, fieldnames=headers)
|
||||
writer.writeheader()
|
||||
writer.writerows(table)
|
||||
print(f"\nWrote {combined}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,169 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Distributed, resumable streaming training on a large HF-hosted dataset.
|
||||
|
||||
This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
|
||||
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
|
||||
streaming features of :class:`StreamingLeRobotDataset`:
|
||||
|
||||
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
||||
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
||||
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
||||
- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
|
||||
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
||||
|
||||
Launch with Accelerate (single node, N GPUs):
|
||||
|
||||
accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
|
||||
--repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64
|
||||
|
||||
Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.
|
||||
|
||||
Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--repo_id", type=str, default="lerobot/droid_1.0.1")
|
||||
parser.add_argument(
|
||||
"--root", type=str, default=None, help="Local/prewarmed dataset root (else stream from Hub)."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default="outputs/train/streaming_multinode")
|
||||
parser.add_argument("--steps", type=int, default=1000)
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
|
||||
)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
|
||||
parser.add_argument("--save_freq", type=int, default=200)
|
||||
parser.add_argument("--log_freq", type=int, default=20)
|
||||
parser.add_argument("--resume_from", type=str, default=None, help="Checkpoint dir to resume from.")
|
||||
parser.add_argument("--dummy", action="store_true", help="Skip the model; measure dataloading only.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_dataloader(
|
||||
args: argparse.Namespace, meta: LeRobotDatasetMetadata
|
||||
) -> tuple[DataLoader, StreamingLeRobotDataset]:
|
||||
# Supervise an action chunk; delta_timestamps drive the SARM-style temporal window.
|
||||
delta_timestamps = {ACTION: [t / meta.fps for t in range(args.n_action_steps)]}
|
||||
# rank / world_size are resolved automatically from the Accelerate state inside the dataset.
|
||||
dataset = StreamingLeRobotDataset(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None,
|
||||
)
|
||||
return loader, dataset
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
accelerator = Accelerator()
|
||||
output_dir = Path(args.output_dir)
|
||||
if accelerator.is_main_process:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
|
||||
loader, dataset = make_dataloader(args, meta)
|
||||
|
||||
if args.dummy:
|
||||
model = optimizer = None
|
||||
else:
|
||||
from lerobot.policies.act import ACTConfig, ACTPolicy
|
||||
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
features = dataset_to_policy_features(meta.features)
|
||||
output_features = {k: ft for k, ft in features.items() if k == ACTION}
|
||||
input_features = {k: ft for k, ft in features.items() if k not in output_features}
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
model = ACTPolicy(cfg)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
|
||||
|
||||
# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
|
||||
# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
|
||||
# checkpoint this script wrote itself.
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resumed dataset stream from {args.resume_from}")
|
||||
|
||||
step = 0
|
||||
frames_seen = 0
|
||||
window_start = time.perf_counter()
|
||||
done = False
|
||||
while not done:
|
||||
for batch in loader:
|
||||
if model is not None:
|
||||
batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
loss, _ = model.forward(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
step += 1
|
||||
frames_seen += args.batch_size
|
||||
if step % args.log_freq == 0:
|
||||
elapsed = time.perf_counter() - window_start
|
||||
fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
|
||||
total_fps = fps_per_proc * accelerator.num_processes
|
||||
accelerator.print(
|
||||
f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
|
||||
+ ("" if model is None else f" | loss {loss.item():.3f}")
|
||||
)
|
||||
window_start = time.perf_counter()
|
||||
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the dataset stream position alongside the model so a restart resumes mid-stream.
|
||||
torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
|
||||
if model is not None:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
if step >= args.steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -214,6 +214,7 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
recap = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
@@ -296,6 +297,7 @@ all = [
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[recap]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench_stream
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=02:00:00
|
||||
#SBATCH --output=logs/%x-%j.out
|
||||
|
||||
# Per-node dataloading benchmark for StreamingLeRobotDataset across 1-2 nodes. Each node runs an
|
||||
# independent dummy-consumer benchmark; per-node throughput should be independent (separate network).
|
||||
# Results are written per (node, source, mode) under --out_dir.
|
||||
#
|
||||
# Submit with: sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
# Override the source label for cold/warm bucket runs: SOURCE=warmed_bucket sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
SOURCE=${SOURCE:-hub}
|
||||
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
|
||||
|
||||
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# One benchmark process per node (each saturates the node's DataLoader workers + network independently).
|
||||
srun --kill-on-bad-exit=1 bash -c '
|
||||
for MODE in single sarm; do
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--source '"$SOURCE"' \
|
||||
--mode $MODE \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--buffer_size 4000 \
|
||||
--num_batches 300 \
|
||||
--out_dir '"$OUT_DIR"'/node${SLURM_NODEID}
|
||||
done
|
||||
'
|
||||
@@ -1,100 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Submit the FULL streaming dataloading-benchmark matrix as isolated single-GPU SLURM jobs.
|
||||
#
|
||||
# sources : hub (Hub streaming) | bucket (cold HF bucket) | warmed_bucket (prewarmed HF bucket)
|
||||
# modes : single (1 frame, all cameras) | sarm (8-step / 8s delta window)
|
||||
# decode : cpu (torchcodec on CPU, scales with workers) | cuda (NVDEC, offloads decode to the GPU)
|
||||
#
|
||||
# => 3 x 2 x 2 = 12 jobs. Each runs in its OWN job (1 node, 1 GPU) so an OOM is isolated and reported
|
||||
# per-job by SLURM (check `sacct -j <id> --format=JobID,State,MaxRSS,ReqMem`). Submit from a login node
|
||||
# inside the repo: bash slurm/run_streaming_matrix.sh
|
||||
#
|
||||
# SERIAL (default 1): chain the jobs with --dependency=afterany so SLURM runs exactly ONE at a time. This
|
||||
# is important for a bandwidth benchmark — concurrent jobs would share the network to the Hub/bucket and
|
||||
# corrupt every throughput number. `afterany` means a failed/OOM'd job does not stall the chain. Set
|
||||
# SERIAL=0 to let the scheduler run them in parallel (only for OOM-isolation testing, not for throughput).
|
||||
#
|
||||
# Knobs (env overrides):
|
||||
# REPO_ID, BUCKET, WARM_BUCKET, OUT_DIR, NUM_BATCHES, TIME, MEM, GPUS, SERIAL
|
||||
# CPU_WORKERS / CPU_BUFFER (cpu-decode jobs) GPU_WORKERS / GPU_BUFFER (cuda-decode jobs, kept low to
|
||||
# bound VRAM + NVDEC sessions). RUN ("python" by default; set RUN="uv run python" if using uv).
|
||||
# SOURCES / MODES / DECODES to run a subset (e.g. SOURCES="hub bucket" DECODES="cpu").
|
||||
# ACCOUNT / PARTITION / QOS passed through to sbatch if set.
|
||||
set -euo pipefail
|
||||
|
||||
REPO_DIR=$(git rev-parse --show-toplevel)
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
BUCKET=${BUCKET:-hf://buckets/pepijn223/robocasa-stream}
|
||||
WARM_BUCKET=${WARM_BUCKET:-hf://buckets/pepijn223/robocasa-stream-warm}
|
||||
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
|
||||
NUM_BATCHES=${NUM_BATCHES:-200}
|
||||
TIME=${TIME:-01:00:00}
|
||||
MEM=${MEM:-64G}
|
||||
GPUS=${GPUS:-1}
|
||||
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
|
||||
CPU_WORKERS=${CPU_WORKERS:-8}
|
||||
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
|
||||
CPU_BUFFER=${CPU_BUFFER:-4000}
|
||||
GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory
|
||||
BATCH_SIZE=${BATCH_SIZE:-64}
|
||||
RUN=${RUN:-python}
|
||||
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
|
||||
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
|
||||
# + datasets (>=4.7) — the default `base` env on many clusters is too old to decode AV1 / lacks CUDA.
|
||||
CONDA_ENV=${CONDA_ENV:-}
|
||||
if [ -n "$CONDA_ENV" ] && [ "$RUN" = "python" ]; then
|
||||
RUN="conda run --no-capture-output -n $CONDA_ENV python"
|
||||
fi
|
||||
|
||||
SOURCES=${SOURCES:-"hub bucket warmed_bucket"}
|
||||
MODES=${MODES:-"single sarm"}
|
||||
DECODES=${DECODES:-"cpu cuda"}
|
||||
|
||||
mkdir -p "$REPO_DIR/logs" "$REPO_DIR/$OUT_DIR"
|
||||
|
||||
data_root_for () {
|
||||
case "$1" in
|
||||
hub) echo "" ;;
|
||||
bucket) echo "$BUCKET" ;;
|
||||
warmed_bucket) echo "$WARM_BUCKET" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
n=0
|
||||
prev_jid=""
|
||||
for SOURCE in $SOURCES; do
|
||||
DATA_ROOT=$(data_root_for "$SOURCE")
|
||||
ROOTFLAG=""
|
||||
[ -n "$DATA_ROOT" ] && ROOTFLAG="--data_files_root $DATA_ROOT"
|
||||
for MODE in $MODES; do
|
||||
for DECODE in $DECODES; do
|
||||
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
|
||||
# Run strictly after the previous job so only one job touches the network at a time.
|
||||
DEPFLAG=""
|
||||
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
|
||||
jid=$(sbatch --parsable \
|
||||
--job-name="bench_${SOURCE}_${MODE}_${DECODE}" \
|
||||
--nodes=1 --ntasks=1 --gpus="$GPUS" --cpus-per-task=$((W + 4)) \
|
||||
--mem="$MEM" --time="$TIME" --output="$REPO_DIR/logs/%x-%j.out" \
|
||||
$DEPFLAG \
|
||||
${ACCOUNT:+--account=$ACCOUNT} ${PARTITION:+--partition=$PARTITION} ${QOS:+--qos=$QOS} \
|
||||
--wrap "cd '$REPO_DIR' && \
|
||||
export TOKENIZERS_PARALLELISM=false && export HF_HOME=\${HF_HOME:-\$SCRATCH/hf_home} && \
|
||||
$RUN benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id $REPO_ID $ROOTFLAG \
|
||||
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
|
||||
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
|
||||
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
|
||||
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
|
||||
prev_jid=$jid
|
||||
n=$((n + 1))
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
|
||||
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
|
||||
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_<decode>.{json,csv}"
|
||||
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
|
||||
@@ -1,49 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=stream_robocasa
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --output=logs/%x-%j.out
|
||||
|
||||
# Multinode streaming training over a large HF-hosted RoboCasa dataset (never touches local disk).
|
||||
# Launches examples/scaling/train_streaming_multinode.py with Accelerate. Each rank streams a disjoint
|
||||
# set of shards via split_dataset_by_node (auto-resolved from the Accelerate state), so per-node
|
||||
# throughput scales independently. For an even split, ensure n_shards % (nodes * gpus_per_node) == 0.
|
||||
#
|
||||
# Submit with: sbatch slurm/train_streaming_robocasa.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
GPUS_PER_NODE=8
|
||||
NUM_PROCESSES=$((SLURM_NNODES * GPUS_PER_NODE))
|
||||
|
||||
# Rendezvous: use the first node in the allocation as the main process.
|
||||
MAIN_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
|
||||
MAIN_PORT=${MAIN_PORT:-29500}
|
||||
|
||||
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
|
||||
# Avoid each rank fighting over the tokenizers' internal thread pool.
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
srun --kill-on-bad-exit=1 bash -c '
|
||||
accelerate launch \
|
||||
--num_machines '"$SLURM_NNODES"' \
|
||||
--num_processes '"$NUM_PROCESSES"' \
|
||||
--machine_rank $SLURM_NODEID \
|
||||
--main_process_ip '"$MAIN_ADDR"' \
|
||||
--main_process_port '"$MAIN_PORT"' \
|
||||
--mixed_precision bf16 \
|
||||
--dynamo_backend no \
|
||||
examples/scaling/train_streaming_multinode.py \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--buffer_size 4000 \
|
||||
--steps 200000 \
|
||||
--save_freq 2000 \
|
||||
--log_freq 50
|
||||
'
|
||||
@@ -39,9 +39,6 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
|
||||
# of host RAM. Ignored when streaming is False.
|
||||
streaming_buffer_size: int = 1000
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
buffer_size=cfg.dataset.streaming_buffer_size,
|
||||
max_num_shards=cfg.num_workers,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -13,9 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from pathlib import Path
|
||||
@@ -24,7 +21,6 @@ import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
@@ -42,8 +38,6 @@ from .video_utils import (
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
@@ -258,11 +252,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
rank: int | None = None,
|
||||
world_size: int | None = None,
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
video_decode_device: str = "cpu",
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -283,25 +272,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training.
|
||||
Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is
|
||||
resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
|
||||
world_size (int | None, optional): Total number of distributed processes. When omitted, resolved
|
||||
from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding).
|
||||
For an even per-rank split, ``num_shards % world_size == 0`` should hold.
|
||||
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
|
||||
When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working
|
||||
set of live decoders never thrashes. See :class:`VideoDecoderCache`.
|
||||
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
|
||||
trees (e.g. an HF storage bucket ``hf://buckets/<owner>/<name>``). When set, parquet and
|
||||
video frames are read from there while metadata still loads from ``repo_id`` on the Hub.
|
||||
Resolves through fsspec exactly like ``hf://``; use it to benchmark bucket / prewarmed-bucket
|
||||
sources without copying the (small) metadata.
|
||||
video_decode_device (str, optional): Device for video decoding, passed to the torchcodec
|
||||
``VideoDecoder``. Defaults to ``"cpu"``. Set to ``"cuda"`` to offload H.264/H.265 decode to
|
||||
the GPU's dedicated NVDEC engine (independent of the training SMs), which requires a
|
||||
CUDA-enabled torchcodec build. Note: ``"cuda"`` decode inside ``DataLoader`` workers needs
|
||||
the ``spawn`` start method (CUDA cannot init in forked workers).
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -319,21 +289,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self.max_num_shards = max_num_shards
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
self.video_decode_device = video_decode_device
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
|
||||
# one place the main process can read after iteration (see video_decoder_cache_stats()).
|
||||
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
|
||||
# Resume state captured by load_state_dict() and consumed at the next __iter__.
|
||||
self._resume_state: dict | None = None
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -355,31 +314,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
if self.data_files_root is not None:
|
||||
# Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub.
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
"parquet",
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files=f"{self.data_files_root}/data/*/*.parquet",
|
||||
)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Drop any parquet columns not declared in the dataset's feature contract. Some revisions / sources
|
||||
# (e.g. an unversioned bucket holding `main`) carry extra, possibly variable-length annotation
|
||||
# columns such as `language_events`; left in, they leak into the sample and break default DataLoader
|
||||
# collation across frames of differing length. On a clean revision this is a no-op.
|
||||
known_columns = set(self.meta.features)
|
||||
extra_columns = [c for c in (self.hf_dataset.column_names or []) if c not in known_columns]
|
||||
if extra_columns:
|
||||
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
|
||||
@@ -407,99 +348,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
|
||||
"""Resolve (rank, world_size) for distributed streaming.
|
||||
|
||||
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
|
||||
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
|
||||
"""
|
||||
if rank is not None and world_size is not None:
|
||||
return rank, world_size
|
||||
|
||||
try:
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if PartialState._shared_state: # only read it if already initialized; never initialize here
|
||||
state = PartialState()
|
||||
return state.process_index, state.num_processes
|
||||
except Exception:
|
||||
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
|
||||
|
||||
env_rank = os.environ.get("RANK")
|
||||
env_world = os.environ.get("WORLD_SIZE")
|
||||
if env_rank is not None and env_world is not None:
|
||||
return int(env_rank), int(env_world)
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache:
|
||||
"""Size the decoder cache to the working set of live shards so it does not thrash.
|
||||
|
||||
Each shard mid-episode keeps one open decoder per camera; with several shards iterated
|
||||
concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of
|
||||
margin so the round-robin never evicts a still-live decoder.
|
||||
"""
|
||||
if self.video_decoder_cache_size is not None:
|
||||
return VideoDecoderCache(
|
||||
max_size=self.video_decoder_cache_size,
|
||||
counters=self._cache_counters,
|
||||
device=self.video_decode_device,
|
||||
)
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if num_cameras == 0:
|
||||
return VideoDecoderCache(counters=self._cache_counters, device=self.video_decode_device)
|
||||
return VideoDecoderCache(
|
||||
max_size=(num_active_shards + 1) * num_cameras,
|
||||
counters=self._cache_counters,
|
||||
device=self.video_decode_device,
|
||||
)
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
# Distributed correctness: each rank streams a disjoint set of shards (order preserved).
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
num_shards = min(ds.num_shards, self.max_num_shards)
|
||||
shard_indices = list(range(num_shards))
|
||||
|
||||
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is not None:
|
||||
shard_indices = shard_indices[worker_info.id :: worker_info.num_workers]
|
||||
|
||||
self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices))
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
# Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The
|
||||
# shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable
|
||||
# for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead.
|
||||
resume = self._resume_state
|
||||
self._resume_state = None
|
||||
self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set()
|
||||
if resume is not None:
|
||||
rng.bit_generator.state = resume["rng"]
|
||||
|
||||
self._shards: dict[int, datasets.IterableDataset] = {}
|
||||
for idx in shard_indices:
|
||||
shard = safe_shard(ds, idx, num_shards)
|
||||
if resume is not None and str(idx) in resume["shards"]:
|
||||
shard.load_state_dict(resume["shards"][str(idx)])
|
||||
self._shards[idx] = shard
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(shard)
|
||||
for idx, shard in self._shards.items()
|
||||
if idx not in self._exhausted
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
@@ -525,47 +389,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
self._exhausted.add(shard_key)
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
|
||||
|
||||
Must be called after iteration has started (so the shard streams exist). Restore the returned
|
||||
dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so
|
||||
resumption is not bit-exact — see :meth:`__iter__`.
|
||||
"""
|
||||
if not hasattr(self, "_shards"):
|
||||
raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.")
|
||||
return {
|
||||
"shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()},
|
||||
"exhausted": sorted(self._exhausted),
|
||||
"rng": self.rng.bit_generator.state,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``."""
|
||||
self._resume_state = state_dict
|
||||
|
||||
def video_decoder_cache_stats(self) -> dict[str, int | float]:
|
||||
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
|
||||
|
||||
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
|
||||
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
|
||||
approximate; the ``hit_rate`` ratio is preserved.
|
||||
"""
|
||||
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
|
||||
total = hits + misses
|
||||
return {
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"evictions": evictions,
|
||||
"hit_rate": round(hits / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
@@ -577,23 +405,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically size the windows to exactly cover the requested delta_timestamps (in frames).
|
||||
# This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for
|
||||
# long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20).
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = math.floor(min(all_timestamps) * self.fps)
|
||||
lookahead = math.ceil(max(all_timestamps) * self.fps)
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else -lookback
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
|
||||
return lookback, lookahead
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True)
|
||||
# Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires
|
||||
# history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1.
|
||||
return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead))
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
@@ -649,20 +473,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# `timestamp` is episode-local (restarts at 0 each episode). The absolute in-file timestamp is
|
||||
# `from_timestamp + timestamp`, applied per camera at decode time (see `_query_videos`), mirroring
|
||||
# the map-style reader. Using `index / fps` here is a dataset-global value that only matches the
|
||||
# file timeline when the whole dataset is a single video (e.g. small test fixtures), and otherwise
|
||||
# decodes out-of-range frames on multi-file v3 datasets.
|
||||
current_ts = float(item["timestamp"])
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
|
||||
# Per-camera episode-local bounds [0, duration]. Query timestamps are clamped into this range so
|
||||
# out-of-episode deltas pad rather than decode against a neighbouring episode in the same file.
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
0.0,
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
|
||||
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
@@ -735,19 +552,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
|
||||
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
if self.data_files_root is not None:
|
||||
root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
root = self.meta.url_root
|
||||
else:
|
||||
root = self.root
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
|
||||
@@ -242,12 +242,7 @@ class VideoDecoderCache:
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int | None | object = _SENTINEL,
|
||||
counters: "torch.Tensor | None" = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
@@ -255,18 +250,6 @@ class VideoDecoderCache:
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
# Decode device for the underlying torchcodec VideoDecoder. "cuda" offloads H.264/H.265 decode to
|
||||
# the GPU's dedicated NVDEC engine (independent of the SMs used for training); requires a
|
||||
# CUDA-enabled torchcodec/FFmpeg build. See https://developer.nvidia.com/video-codec-sdk.
|
||||
self.device = device
|
||||
# Observability counters (cheap, updated under the lock) for benchmarking decoder reuse.
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.evictions = 0
|
||||
# Optional shared [hits, misses, evictions] tensor so DataLoader workers aggregate into one place
|
||||
# (the per-worker `self.*` ints are invisible to the main process). Lock-free across processes, so
|
||||
# treat the aggregate as approximate; the hit-rate ratio is preserved.
|
||||
self._counters = counters
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
@@ -288,17 +271,11 @@ class VideoDecoderCache:
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
self.hits += 1
|
||||
if self._counters is not None:
|
||||
self._counters[0] += 1
|
||||
return entry[0]
|
||||
|
||||
self.misses += 1
|
||||
if self._counters is not None:
|
||||
self._counters[1] += 1
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
@@ -310,9 +287,6 @@ class VideoDecoderCache:
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
self.evictions += 1
|
||||
if self._counters is not None:
|
||||
self._counters[2] += 1
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
@@ -331,18 +305,6 @@ class VideoDecoderCache:
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
def stats(self) -> dict[str, int | float]:
|
||||
"""Return reuse counters (hits/misses/evictions, hit rate, current size) for benchmarking."""
|
||||
with self._lock:
|
||||
total = self.hits + self.misses
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"hit_rate": self.hits / total if total else 0.0,
|
||||
"size": len(self._cache),
|
||||
}
|
||||
|
||||
|
||||
class FrameTimestampError(ValueError):
|
||||
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig as DistributionalVFConfig,
|
||||
)
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
@@ -26,6 +29,7 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"DistributionalVFConfig",
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"DistributionalVFConfig",
|
||||
"DistributionalVFRewardModel",
|
||||
"make_distributional_vf_pre_post_processors",
|
||||
]
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("distributional_value_function")
|
||||
@dataclass
|
||||
class DistributionalVFConfig(RewardModelConfig):
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
|
||||
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
|
||||
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
|
||||
(Eq. 1 in the paper).
|
||||
|
||||
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
|
||||
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
|
||||
about 670M params and initialize from the PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
# Backbone
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
num_hidden_layers: int = 6
|
||||
num_vision_layers: int = 13
|
||||
|
||||
# Distributional head
|
||||
num_value_bins: int = 201
|
||||
value_support_min: float = -1.0
|
||||
value_support_max: float = 0.0
|
||||
hl_gauss_sigma_ratio: float = 5.0
|
||||
|
||||
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
|
||||
target_method: str = "hl_gauss"
|
||||
|
||||
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
|
||||
# When False, terminal states use the same target method as non-terminal states.
|
||||
use_one_hot_terminal: bool = True
|
||||
|
||||
# Image
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 64
|
||||
|
||||
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
|
||||
# Pass a PI05 checkpoint path or Hub repo_id here.
|
||||
# After training, load the value function with RewardModel.from_pretrained() instead.
|
||||
init_from_actor_path: str = ""
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=1e-4,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.input_features:
|
||||
return
|
||||
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
|
||||
if not has_image:
|
||||
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
|
||||
+567
@@ -0,0 +1,567 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modeling for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
|
||||
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaRMSNorm,
|
||||
_gated_residual,
|
||||
_get_pi_gemma_decoder_layer_base,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
PiGemmaRMSNorm = None
|
||||
_gated_residual = None
|
||||
_get_pi_gemma_decoder_layer_base = None
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257152
|
||||
|
||||
|
||||
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||
"""Distributional value function model for RECAP.
|
||||
|
||||
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||
per-task normalized Monte Carlo returns.
|
||||
|
||||
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||
"""
|
||||
|
||||
name = "distributional_value_function"
|
||||
config_class = DistributionalVFConfig
|
||||
|
||||
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
|
||||
require_package("transformers", extra="recap")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
|
||||
|
||||
# Get base dimensions from the paligemma variant (OpenPI config format)
|
||||
base_config = get_gemma_config(config.paligemma_variant)
|
||||
hidden_dim = base_config.width
|
||||
mlp_dim = base_config.mlp_dim
|
||||
num_layers = config.num_hidden_layers
|
||||
|
||||
# HuggingFace GemmaConfig for transformer layers
|
||||
gemma_config = CONFIG_MAPPING["gemma"](
|
||||
head_dim=base_config.head_dim,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=mlp_dim,
|
||||
num_attention_heads=base_config.num_heads,
|
||||
num_hidden_layers=num_layers,
|
||||
num_key_value_heads=base_config.num_kv_heads,
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
)
|
||||
self.gemma_config = gemma_config
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_value_bins = config.num_value_bins
|
||||
|
||||
# Single learned [CLS] token for value prediction
|
||||
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
||||
|
||||
# Value projection head: Linear(hidden_dim, num_bins)
|
||||
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
|
||||
|
||||
# Transformer layers (overwritten by _initialize_from_actor on first run)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
|
||||
|
||||
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
|
||||
# PaliGemmaConfig wraps both vision and text configs into a single model
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"]()
|
||||
paligemma_config.text_config = gemma_config
|
||||
paligemma_config.vision_config.image_size = config.image_resolution[0]
|
||||
paligemma_config.vision_config.intermediate_size = 4304
|
||||
paligemma_config.vision_config.projection_dim = 2048
|
||||
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
|
||||
|
||||
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
|
||||
self.vision_tower = paligemma_full.model.vision_tower
|
||||
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
|
||||
self.token_embedding = paligemma_full.model.language_model.embed_tokens
|
||||
del paligemma_full
|
||||
|
||||
# Truncate vision tower to num_vision_layers
|
||||
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
|
||||
vision_encoder = self.vision_tower.vision_model.encoder
|
||||
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
|
||||
|
||||
# Bin support: evenly spaced centers from value_support_min to value_support_max
|
||||
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
|
||||
self.register_buffer("bin_centers", bin_centers, persistent=False)
|
||||
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
|
||||
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
|
||||
|
||||
# Overwrite with pre-trained PI05 actor weights (first training run only)
|
||||
if config.init_from_actor_path:
|
||||
self._initialize_from_actor()
|
||||
|
||||
def _initialize_from_actor(self) -> None:
|
||||
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
|
||||
|
||||
Called on first training run only (when init_from_actor_path is set).
|
||||
"""
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
|
||||
actor_model = actor_policy.model
|
||||
|
||||
paligemma_model = actor_model.paligemma_with_expert.paligemma
|
||||
source_language_model = paligemma_model.model.language_model
|
||||
|
||||
# Transformer components
|
||||
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
|
||||
num_layers = self.gemma_config.num_hidden_layers
|
||||
for i in range(num_layers):
|
||||
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
|
||||
self.norm.load_state_dict(source_language_model.norm.state_dict())
|
||||
|
||||
# Vision tower (truncate source first, then copy)
|
||||
source_vision_tower = paligemma_model.model.vision_tower
|
||||
if hasattr(source_vision_tower, "vision_model") and hasattr(
|
||||
source_vision_tower.vision_model, "encoder"
|
||||
):
|
||||
source_encoder = source_vision_tower.vision_model.encoder
|
||||
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
|
||||
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
|
||||
|
||||
# Multi-modal projector
|
||||
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
|
||||
|
||||
# Token embedding table
|
||||
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
|
||||
|
||||
del actor_policy
|
||||
|
||||
def embed_image(self, image: Tensor) -> Tensor:
|
||||
"""Embed images using the value function's SigLIP vision tower.
|
||||
|
||||
Args:
|
||||
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
|
||||
|
||||
Returns:
|
||||
[batch_size, num_patches, hidden_dim] projected image features.
|
||||
"""
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
image_outputs = self.vision_tower(image, return_dict=True)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
image_features = image_features / (self.hidden_dim**0.5)
|
||||
|
||||
if image_features.dtype != out_dtype:
|
||||
image_features = image_features.to(out_dtype)
|
||||
return image_features
|
||||
|
||||
def embed_text(self, token_ids: Tensor) -> Tensor:
|
||||
"""Embed text token IDs using the value function's token embedding table.
|
||||
|
||||
Args:
|
||||
token_ids: [batch_size, seq_len] integer token IDs
|
||||
|
||||
Returns:
|
||||
[batch_size, seq_len, hidden_dim] text embeddings
|
||||
"""
|
||||
return self.token_embedding(token_ids)
|
||||
|
||||
def _get_cls_embedding(self, batch_size: int) -> Tensor:
|
||||
"""Get [CLS] token embedding expanded to batch size.
|
||||
|
||||
Args:
|
||||
batch_size: number of samples in the batch.
|
||||
|
||||
Returns:
|
||||
[batch_size, 1, hidden_dim] learned [CLS] embedding.
|
||||
"""
|
||||
return self.cls_embedding.expand(batch_size, -1, -1)
|
||||
|
||||
def forward_value(
|
||||
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
|
||||
) -> dict[str, Tensor]:
|
||||
"""Core forward pass through the distributional value function.
|
||||
|
||||
Args:
|
||||
vision_features: [batch_size, num_patches, hidden_dim]
|
||||
text_embeddings: [batch_size, seq_len, hidden_dim]
|
||||
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
|
||||
|
||||
Returns:
|
||||
logits: [batch_size, num_value_bins]
|
||||
probs: [batch_size, num_value_bins]
|
||||
value: [batch_size, 1]
|
||||
"""
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
|
||||
|
||||
batch_size = text_embeddings.shape[0]
|
||||
device = text_embeddings.device
|
||||
|
||||
# Build sequence: [vision, text, CLS]
|
||||
cls_embedding = self._get_cls_embedding(batch_size)
|
||||
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
|
||||
|
||||
# Build causal attention mask
|
||||
vision_len = vision_features.shape[1]
|
||||
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
|
||||
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
|
||||
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
|
||||
|
||||
full_seq_len = full_padding_mask.shape[1]
|
||||
|
||||
# Causal mask
|
||||
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
|
||||
# Combine causal mask with padding mask
|
||||
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
|
||||
batch_size, 1, full_seq_len, full_seq_len
|
||||
)
|
||||
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
|
||||
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
|
||||
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
norm_output = layer.input_layernorm(hidden_states, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
hidden_states_normed, gate = norm_output
|
||||
else:
|
||||
hidden_states_normed, gate = norm_output, None
|
||||
|
||||
input_shape = hidden_states_normed.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
attention_output, _ = modeling_gemma.eager_attention_forward(
|
||||
layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
layer.self_attn.scaling,
|
||||
)
|
||||
|
||||
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
|
||||
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
projected_attention = layer.self_attn.o_proj(attention_output)
|
||||
|
||||
if gate is not None:
|
||||
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
|
||||
else:
|
||||
projected_attention = hidden_states + projected_attention
|
||||
|
||||
after_attention_residual = projected_attention.clone()
|
||||
|
||||
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
mlp_input, gate = norm_output
|
||||
else:
|
||||
mlp_input, gate = norm_output, None
|
||||
|
||||
mlp_output = layer.mlp(mlp_input)
|
||||
|
||||
if gate is not None:
|
||||
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
|
||||
else:
|
||||
hidden_states = after_attention_residual + mlp_output
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
# Extract [CLS] token (last position in the sequence)
|
||||
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
|
||||
|
||||
# Value head: Linear(hidden_dim, num_bins) -> logits
|
||||
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
|
||||
value_probs = F.softmax(value_logits, dim=-1)
|
||||
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||
|
||||
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
|
||||
"""HL-Gauss soft target distribution.
|
||||
|
||||
Places a Gaussian N(target, sigma^2) over the bin support and computes
|
||||
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
|
||||
|
||||
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
|
||||
Functions via Classification for Scalable Deep RL", Section 3.1.
|
||||
arXiv:2403.03950
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
# Bin edges: half a bin-width outside the first/last center
|
||||
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
|
||||
self.num_value_bins - 1
|
||||
)
|
||||
support_edges = torch.linspace(
|
||||
self.config.value_support_min - bin_width / 2,
|
||||
self.config.value_support_max + bin_width / 2,
|
||||
self.num_value_bins + 1,
|
||||
device=target_value.device,
|
||||
dtype=target_value.dtype,
|
||||
)
|
||||
|
||||
# CDF of N(target, sigma^2) evaluated at each edge
|
||||
cdf_at_edges = 0.5 * (
|
||||
1.0
|
||||
+ torch.erf(
|
||||
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
|
||||
/ (self.hl_gauss_sigma * math.sqrt(2))
|
||||
)
|
||||
) # [batch_size, num_bins + 1]
|
||||
|
||||
# Normalize: z = cdf(max_edge) - cdf(min_edge)
|
||||
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
|
||||
|
||||
# Bin probabilities = differences of consecutive CDF values, normalized
|
||||
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
|
||||
|
||||
return bin_probabilities
|
||||
|
||||
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
|
||||
"""Dirac delta (C51) projection: split probability between two nearest bins.
|
||||
|
||||
Standard distributional RL projection from Bellemare et al. 2017.
|
||||
"A Distributional Perspective on Reinforcement Learning"
|
||||
arXiv:1707.06887
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
bin_width = self.bin_centers[1] - self.bin_centers[0]
|
||||
normalized_position = (target_value - self.config.value_support_min) / bin_width
|
||||
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
|
||||
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
|
||||
|
||||
weight_upper = normalized_position - lower_bin_idx.float()
|
||||
weight_lower = upper_bin_idx.float() - normalized_position
|
||||
|
||||
same_bin = lower_bin_idx == upper_bin_idx
|
||||
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
|
||||
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
|
||||
|
||||
batch_size = target_value.shape[0]
|
||||
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
|
||||
batch_indices = torch.arange(batch_size, device=target_value.device)
|
||||
target_distribution[batch_indices, lower_bin_idx] += weight_lower
|
||||
target_distribution[batch_indices, upper_bin_idx] += weight_upper
|
||||
|
||||
return target_distribution
|
||||
|
||||
def one_hot_target(self, target_value: Tensor) -> Tensor:
|
||||
"""One-hot target for terminal states (exact return, no smoothing).
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
nearest_bin_idx = torch.argmin(
|
||||
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
|
||||
)
|
||||
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
|
||||
|
||||
def compute_target_distribution(
|
||||
self,
|
||||
target_value: Tensor,
|
||||
is_terminal: Tensor,
|
||||
method: str = "hl_gauss",
|
||||
use_one_hot_terminal: bool = True,
|
||||
) -> Tensor:
|
||||
"""Compute target distribution using configured method.
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] scalar return targets
|
||||
is_terminal: [batch_size] boolean terminal flags
|
||||
method: "hl_gauss" or "dirac_delta"
|
||||
use_one_hot_terminal: if True, terminal states get one-hot targets
|
||||
(exact return, no smoothing). If False, all states use the same method.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution
|
||||
"""
|
||||
if method == "hl_gauss":
|
||||
base_distribution = self.hl_gauss_target(target_value)
|
||||
elif method == "dirac_delta":
|
||||
base_distribution = self.dirac_delta_target(target_value)
|
||||
else:
|
||||
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
|
||||
|
||||
if not use_one_hot_terminal:
|
||||
return base_distribution
|
||||
|
||||
terminal_distribution = self.one_hot_target(target_value)
|
||||
|
||||
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — computes cross-entropy loss against MC return targets.
|
||||
|
||||
The batch is expected to be preprocessed by the processor pipeline.
|
||||
Keys expected in batch:
|
||||
- observation.images.*: [B, C, H, W] preprocessed images
|
||||
- observation.language_tokens: [B, seq_len] tokenized task prompt
|
||||
- observation.language_attention_mask: [B, seq_len] padding mask
|
||||
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
|
||||
- is_terminal: [B] boolean terminal flags
|
||||
|
||||
Returns:
|
||||
(loss, output_dict) where loss is scalar cross-entropy
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
# Get first image key from batch
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
mc_return = batch["mc_return"]
|
||||
is_terminal = batch["is_terminal"]
|
||||
|
||||
# Embed observations
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
# Forward through value function transformer
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
value_logits = vf_output["logits"]
|
||||
predicted_value = vf_output["value"]
|
||||
|
||||
# Compute target distribution
|
||||
target_distribution = self.compute_target_distribution(
|
||||
mc_return,
|
||||
is_terminal,
|
||||
method=self.config.target_method,
|
||||
use_one_hot_terminal=self.config.use_one_hot_terminal,
|
||||
)
|
||||
|
||||
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
|
||||
log_probs = F.log_softmax(value_logits, dim=-1)
|
||||
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
|
||||
|
||||
output_dict = {
|
||||
"loss": loss.item(),
|
||||
"predicted_value_mean": predicted_value.mean().item(),
|
||||
"mc_return_mean": mc_return.mean().item(),
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute V(s) for a batch of observations. Used for advantage scoring.
|
||||
|
||||
Args:
|
||||
batch: preprocessed batch with images and tokenized text
|
||||
|
||||
Returns:
|
||||
[batch_size] tensor of predicted values V(s)
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
return vf_output["value"].squeeze(-1) # [batch_size]
|
||||
+235
@@ -0,0 +1,235 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processor for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
|
||||
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
|
||||
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
|
||||
|
||||
Training targets (mc_return, is_terminal) are NOT routed through the processor.
|
||||
They are dataset columns read directly from the batch in the model's forward().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
|
||||
@dataclass
|
||||
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
|
||||
"""Format the task string for the distributional value function.
|
||||
|
||||
The value function receives only visual observations and task text.
|
||||
Builds prompt: "Task: {task}."
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
tasks = complementary_data.get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
full_prompts = []
|
||||
for task in tasks:
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
full_prompts.append(f"Task: {cleaned_text}.")
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[self.task_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
|
||||
@dataclass
|
||||
class DistributionalVFImagePreprocessorStep(ProcessorStep):
|
||||
"""Resize and normalize images for the value function's SigLIP vision tower.
|
||||
|
||||
Expects float images in [0, 1].
|
||||
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
|
||||
- Scale to [-1, 1] for SigLIP
|
||||
"""
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
image_keys: tuple[str, ...] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
|
||||
|
||||
image_keys = self.image_keys or tuple(
|
||||
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
|
||||
)
|
||||
if not image_keys:
|
||||
raise KeyError(
|
||||
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
|
||||
)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for image_key in image_keys:
|
||||
image = new_observation[image_key]
|
||||
if not isinstance(image, Tensor):
|
||||
image = to_tensor(image)
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
is_channels_first = image.ndim == 4 and image.shape[1] == 3
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != self.image_resolution:
|
||||
image = resize_with_pad_torch(image, *self.image_resolution)
|
||||
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2)
|
||||
|
||||
new_observation[image_key] = image
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"image_resolution": self.image_resolution,
|
||||
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
|
||||
}
|
||||
|
||||
|
||||
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
|
||||
return tuple(
|
||||
feature_name
|
||||
for feature_name, feature in config.input_features.items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
)
|
||||
|
||||
|
||||
def make_distributional_vf_pre_post_processors(
|
||||
config: DistributionalVFConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre/post processors for the distributional value function.
|
||||
|
||||
Preprocessor steps:
|
||||
1. Rename observations (no-op by default)
|
||||
2. Add a batch dimension
|
||||
3. Normalize features (images use identity, so they stay in [0, 1])
|
||||
4. Format task prompt: "Task: {task}."
|
||||
5. Tokenize with the PaliGemma tokenizer
|
||||
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
|
||||
7. Move tensors to the configured device
|
||||
|
||||
Training targets (mc_return, is_terminal) are not processed here.
|
||||
The model reads them directly from the batch in forward().
|
||||
|
||||
The postprocessor is a no-op because the value function does not need
|
||||
action postprocessing.
|
||||
"""
|
||||
image_keys = _visual_image_keys(config)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DistributionalVFPrepareTaskPromptStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DistributionalVFImagePreprocessorStep(
|
||||
image_resolution=config.image_resolution,
|
||||
image_keys=image_keys or None,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -24,6 +24,7 @@ from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
@@ -63,6 +64,12 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
elif name == "distributional_value_function":
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -96,6 +103,8 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
elif reward_type == "distributional_value_function":
|
||||
return DistributionalVFConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -191,6 +200,16 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, DistributionalVFConfig):
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
|
||||
return make_distributional_vf_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
|
||||
|
||||
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
|
||||
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
|
||||
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
|
||||
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
WORKER = """
|
||||
import json, sys
|
||||
from accelerate import PartialState
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
state = PartialState()
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8
|
||||
)
|
||||
indices = [int(frame["index"]) for frame in ds]
|
||||
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
|
||||
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
|
||||
json.dump(payload, f)
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
|
||||
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
|
||||
total_frames = 160
|
||||
repo_id = f"{DUMMY_REPO_ID}-acc"
|
||||
root = tmp_path / "ds"
|
||||
lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=8,
|
||||
total_frames=total_frames,
|
||||
use_videos=False,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
)
|
||||
|
||||
worker = tmp_path / "worker.py"
|
||||
worker.write_text(WORKER)
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num_processes=2",
|
||||
"--num_machines=1",
|
||||
"--mixed_precision=no",
|
||||
"--dynamo_backend=no",
|
||||
"--cpu",
|
||||
str(worker),
|
||||
str(root),
|
||||
repo_id,
|
||||
str(out_dir),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
assert result.returncode == 0, (
|
||||
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
|
||||
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
|
||||
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
|
||||
|
||||
rank_sets = [set(p["indices"]) for p in payloads]
|
||||
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
|
||||
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
@@ -1,251 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
|
||||
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
|
||||
|
||||
monkeypatch.delenv("RANK", raising=False)
|
||||
monkeypatch.delenv("WORLD_SIZE", raising=False)
|
||||
# No accelerate state, no env -> single process.
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
|
||||
|
||||
monkeypatch.setenv("RANK", "3")
|
||||
monkeypatch.setenv("WORLD_SIZE", "4")
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
|
||||
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
|
||||
episode_frames = 300
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
|
||||
)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
|
||||
if idx + horizon_frames < episode_frames:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
|
||||
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
stop_after = 40
|
||||
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
|
||||
state = ds.state_dict()
|
||||
assert set(state) == {"shards", "exhausted", "rng"}
|
||||
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
resumed = _stream_indices(resumed_ds)
|
||||
|
||||
# Resume continues rather than replaying: the full first pass is not re-yielded.
|
||||
assert len(resumed) < total_frames
|
||||
overlap = set(seen_before) & set(resumed)
|
||||
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
|
||||
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
|
||||
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
|
||||
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
@@ -0,0 +1,518 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP's distributional value function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_BINS = 201
|
||||
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||
|
||||
|
||||
def _make_config(**overrides) -> DistributionalVFConfig:
|
||||
defaults = {
|
||||
"init_from_actor_path": "",
|
||||
"device": "cpu",
|
||||
"image_resolution": (224, 224),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
config = DistributionalVFConfig(**defaults)
|
||||
config.input_features = {
|
||||
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_model():
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel(_make_config())
|
||||
|
||||
|
||||
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
return {
|
||||
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
|
||||
"mc_return": torch.rand(batch_size, device=device) * -1.0,
|
||||
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
def test_config_registered_in_reward_model_registry():
|
||||
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
|
||||
known = RewardModelConfig.get_known_choices()
|
||||
assert "distributional_value_function" in known
|
||||
|
||||
|
||||
def test_factory_returns_correct_class():
|
||||
"""get_reward_model_class returns DistributionalVFRewardModel."""
|
||||
from lerobot.rewards.factory import get_reward_model_class
|
||||
|
||||
cls = get_reward_model_class("distributional_value_function")
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
assert cls is DistributionalVFRewardModel
|
||||
|
||||
|
||||
def test_make_reward_model_config_factory():
|
||||
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
|
||||
from lerobot.rewards.factory import make_reward_model_config
|
||||
|
||||
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
|
||||
assert isinstance(config, DistributionalVFConfig)
|
||||
assert config.num_value_bins == 101
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_sums_to_one():
|
||||
"""HL-Gauss target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_non_negative():
|
||||
"""HL-Gauss target probabilities are all non-negative."""
|
||||
model = _make_model()
|
||||
targets = torch.linspace(-1.0, 0.0, 10)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert (dist >= 0).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_expected_value_matches():
|
||||
"""E[V] under HL-Gauss distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_handles_2d_input():
|
||||
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (2, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_sums_to_one():
|
||||
"""Dirac delta target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
assert dist.shape == (5, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_at_most_two_nonzero():
|
||||
"""Dirac delta places probability on at most two adjacent bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.7523, -0.0013])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
for i in range(2):
|
||||
assert (dist[i] > 0).sum() <= 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_expected_value_matches():
|
||||
"""E[V] under Dirac delta distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_boundary_values_clamped():
|
||||
"""Values outside support are clamped to boundary bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-1.5, 0.5])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
|
||||
assert dist[0, 0] == 1.0
|
||||
assert dist[1, -1] == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_single_nonzero():
|
||||
"""One-hot target has exactly one non-zero bin per sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
for i in range(4):
|
||||
assert (dist[i] > 0).sum() == 1
|
||||
assert dist[i].sum() == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_nearest_bin():
|
||||
"""One-hot target activates the bin closest to the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
hot_idx = dist[0].argmax()
|
||||
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_terminal_gets_one_hot():
|
||||
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
|
||||
is_terminal = torch.tensor([False, True, False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
|
||||
assert (dist[1] > 0).sum() == 1
|
||||
assert (dist[3] > 0).sum() == 1
|
||||
assert (dist[0] > 0).sum() > 2
|
||||
assert (dist[2] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_no_terminal_override_when_disabled():
|
||||
"""When use_one_hot_terminal=False, terminal states use the base method."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3])
|
||||
is_terminal = torch.tensor([False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
|
||||
)
|
||||
|
||||
assert (dist[1] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_has_expected_components():
|
||||
"""Model scaffold contains all architectural components."""
|
||||
model = _make_model()
|
||||
|
||||
assert hasattr(model, "vision_tower")
|
||||
assert hasattr(model, "multi_modal_projector")
|
||||
assert hasattr(model, "token_embedding")
|
||||
assert hasattr(model, "layers")
|
||||
assert hasattr(model, "value_head")
|
||||
assert hasattr(model, "cls_embedding")
|
||||
assert hasattr(model, "norm")
|
||||
assert hasattr(model, "rotary_emb")
|
||||
assert hasattr(model, "bin_centers")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_bin_centers_shape():
|
||||
"""Bin centers buffer has shape (num_value_bins,)."""
|
||||
model = _make_model()
|
||||
assert model.bin_centers.shape == (NUM_BINS,)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_layer_count():
|
||||
"""Transformer has num_hidden_layers (6) layers."""
|
||||
model = _make_model()
|
||||
assert len(model.layers) == 6
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_value_head_output_dim():
|
||||
"""Value head outputs num_value_bins logits."""
|
||||
model = _make_model()
|
||||
assert model.value_head.out_features == NUM_BINS
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_returns_loss_and_dict():
|
||||
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, output_dict = model.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert "loss" in output_dict
|
||||
assert "predicted_value_mean" in output_dict
|
||||
assert "mc_return_mean" in output_dict
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_loss_is_positive():
|
||||
"""Cross-entropy loss is strictly positive for random weights."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_returns_correct_shape():
|
||||
"""compute_reward returns [batch_size] tensor of finite float32 values."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=3)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert values.shape == (3,)
|
||||
assert values.dtype == torch.float32
|
||||
assert torch.isfinite(values).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_values_in_support_range():
|
||||
"""Predicted values lie within [value_support_min, value_support_max]."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=8)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert (values >= -1.0 - 0.01).all()
|
||||
assert (values <= 0.0 + 0.01).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_processor_pipeline_produces_expected_keys():
|
||||
"""Full preprocessor pipeline produces tokenized text and processed images."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
config = _make_config()
|
||||
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
|
||||
|
||||
raw_batch = {
|
||||
IMAGE_KEY: torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
|
||||
processed = preprocessor(raw_batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert IMAGE_KEY in processed
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_value_head():
|
||||
"""Backprop produces non-zero gradients on the value head."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.value_head.weight.grad is not None
|
||||
assert not torch.all(model.value_head.weight.grad == 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_cls_embedding():
|
||||
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.cls_embedding.grad is not None
|
||||
assert not torch.all(model.cls_embedding.grad == 0)
|
||||
|
||||
|
||||
def test_config_requires_visual_feature():
|
||||
"""validate_features raises if no VISUAL feature is present."""
|
||||
config = DistributionalVFConfig(init_from_actor_path="")
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="VISUAL"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_config_passes_with_visual_feature():
|
||||
"""validate_features succeeds when a VISUAL feature is present."""
|
||||
config = _make_config()
|
||||
config.validate_features()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_save_load_pretrained_roundtrip(tmp_path):
|
||||
"""Saved model can be loaded back with identical weights."""
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
model = _make_model()
|
||||
model._save_pretrained(tmp_path)
|
||||
|
||||
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
orig_sd = model.state_dict()
|
||||
loaded_sd = loaded.state_dict()
|
||||
|
||||
assert set(orig_sd.keys()) == set(loaded_sd.keys())
|
||||
for key in orig_sd:
|
||||
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_normalizes_to_minus_one_one():
|
||||
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 224, 224, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.min() >= -1.0 - 1e-5
|
||||
assert image.max() <= 1.0 + 1e-5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_resizes_with_pad():
|
||||
"""Image preprocessor resizes non-square images to target resolution."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 480, 640, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.shape[1:3] == (224, 224)
|
||||
|
||||
|
||||
def test_task_prompt_formats_correctly():
|
||||
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: pick up the cup."
|
||||
|
||||
|
||||
def test_task_prompt_handles_string_input():
|
||||
"""Task prompt step accepts a plain string (not just a list)."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: open drawer."
|
||||
|
||||
|
||||
def test_task_prompt_raises_on_missing_task():
|
||||
"""Task prompt step raises ValueError when task key is absent."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No task found"):
|
||||
step(transition)
|
||||
Reference in New Issue
Block a user