mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-13 14:39:44 +00:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3ec60da82b | |||
| 7bcd5a1502 | |||
| 674c990a39 | |||
| 38106ea6b4 | |||
| 894fc6bfb5 | |||
| 984b400e5c | |||
| 4e056081cb | |||
| a164bb97bd | |||
| 79b547de32 | |||
| a7b7f4964e | |||
| 1050c2fb6c | |||
| 66ac901632 | |||
| ce326207e6 | |||
| 2ab71231cd | |||
| 41166b39fb | |||
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 | |||
| 42d4788e4a | |||
| 2d1c17d971 | |||
| 7241f029c6 | |||
| 06ddc59913 | |||
| 23c58f5f9e | |||
| b0ab57cedc | |||
| afdc084677 | |||
| a32a2c647b | |||
| 343ecd7980 | |||
| f7c8a526e8 | |||
| 77af66a29c | |||
| 68fa5d80b0 | |||
| d1fc8e298c | |||
| 49755a3d9e | |||
| 09808183ca |
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -0,0 +1,531 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Single-image dataloading benchmark across the LeRobot loaders, MADE TO RUN ON A COMPUTE CLUSTER (SLURM).
|
||||
|
||||
This one file is both the orchestrator and the worker:
|
||||
|
||||
* Run it with no ``--scenario`` (from a login node) and it submits a SERIAL sbatch chain of all
|
||||
scenarios below (no two network-bound jobs overlap, so CDN numbers stay clean).
|
||||
* Run it with ``--scenario <name>`` and it executes that single benchmark (this is what each sbatch
|
||||
job calls). The 2-node scenario is launched with ``srun`` and reads ``RANK``/``WORLD_SIZE`` so the
|
||||
streaming dataset splits shards per node.
|
||||
|
||||
Scenarios (all single-frame / non-SARM):
|
||||
1. ``mmap_local`` map-style LeRobotDataset over a LOCAL copy (``--local_root``, no network).
|
||||
2. ``mmap_local_maxworkers`` same, but workers scaled to saturate the node's cores (decode-bound).
|
||||
3. ``stream_hub`` StreamingLeRobotDataset from the Hub (allenai/MolmoAct2-BimanualYAM-Dataset).
|
||||
4. ``stream_bucket`` StreamingLeRobotDataset from a warmed storage bucket (1 node).
|
||||
5. ``stream_bucket_2node`` same warmed bucket, 2 nodes (split_dataset_by_node, per-rank results).
|
||||
|
||||
Reported per run: peak process-tree RSS (max memory), parallel throughput (samples/s, where a sample
|
||||
is one timestep, plus decoded_frames/s = samples/s x num_cameras),
|
||||
single-process throughput, shuffle randomness fraction (distinct episodes per batch / batch size),
|
||||
fetch vs decode split (% of single-process per-sample time), first-batch latency, and p50/p95/p99
|
||||
sample latency. Results are written as JSON + CSV under ``--out_dir``.
|
||||
|
||||
Submit the whole chain (from a login node, inside the repo). Point the scheduler env vars at your own
|
||||
cluster's account/partition/qos, and ``--local_root`` at a local copy of the map-style dataset:
|
||||
ACCOUNT=<account> PARTITION=<partition> QOS=<qos> \\
|
||||
python examples/scaling/benchmark_dataloading.py --local_root /path/to/local/dataset
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import statistics
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.datasets.partition import group_episodes_by_files, partition_episodes
|
||||
|
||||
ROBOCASA_REPO = "pepijn223/robocasa_pretrain_human300_v4"
|
||||
MOLMO_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
MOLMO_BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
# MolmoAct2 is published without a codebase-version git tag, so the version-safe loader would refuse
|
||||
# it; "main" pins the branch directly and skips that check.
|
||||
MOLMO_REVISION = "main"
|
||||
|
||||
# Per-scenario sbatch shape. mem is generous for the streaming legs (32k-episode, 3-camera, 2.35 TB
|
||||
# dataset keeps many AV1 decoders open); the local map-style leg is light. Optional ``num_workers`` /
|
||||
# ``cpus`` override the CLI defaults for that leg.
|
||||
# ``mmap_local_maxworkers``: map-style decode is CPU-bound and each worker decodes its cameras on
|
||||
# parallel threads, so the saturation point is ~num_cpus / num_cameras workers (~90 concurrent decode
|
||||
# threads). The 96-core H100 nodes here schedule at most 92 cpus/task, so we take 92 cpus / 30 workers.
|
||||
SCENARIOS = {
|
||||
"mmap_local": {"kind": "map", "nodes": 1, "mem": "64G", "time": "01:00:00"},
|
||||
"mmap_local_maxworkers": {
|
||||
"kind": "map",
|
||||
"nodes": 1,
|
||||
"mem": "128G",
|
||||
"time": "01:00:00",
|
||||
"num_workers": 30,
|
||||
"cpus": 92,
|
||||
},
|
||||
"stream_hub": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
|
||||
"stream_bucket": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
|
||||
"stream_bucket_2node": {"kind": "stream", "nodes": 2, "mem": "250G", "time": "03:00:00"},
|
||||
}
|
||||
|
||||
|
||||
def _tree_rss_bytes() -> int:
|
||||
"""Sum RSS of this process and all descendants via /proc (DataLoader workers are separate procs)."""
|
||||
try:
|
||||
children: dict[int, list[int]] = {}
|
||||
for entry in os.listdir("/proc"):
|
||||
if not entry.isdigit():
|
||||
continue
|
||||
try:
|
||||
with open(f"/proc/{entry}/stat") as f:
|
||||
ppid = int(f.read().split(") ", 1)[1].split()[1])
|
||||
children.setdefault(ppid, []).append(int(entry))
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
total, stack = 0, [os.getpid()]
|
||||
while stack:
|
||||
cur = stack.pop()
|
||||
try:
|
||||
with open(f"/proc/{cur}/statm") as f:
|
||||
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
stack.extend(children.get(cur, []))
|
||||
return total
|
||||
except OSError:
|
||||
return 0
|
||||
|
||||
|
||||
class PeakRSSSampler:
|
||||
"""Background thread tracking peak process-tree RSS for the duration of the ``with`` block."""
|
||||
|
||||
def __init__(self, interval_s: float = 0.5):
|
||||
self.interval_s = interval_s
|
||||
self.peak_bytes = 0
|
||||
self._stop = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
|
||||
self._stop.wait(self.interval_s)
|
||||
|
||||
def __enter__(self) -> "PeakRSSSampler":
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self._stop.set()
|
||||
self._thread.join(timeout=2)
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
class _TimedStreaming(StreamingLeRobotDataset):
|
||||
"""StreamingLeRobotDataset that times the fetch stage (parquet/network row) separately from the
|
||||
decode stage (video decode + torch conversion in ``_finalize_sample``), so a single-process pass
|
||||
can attribute per-sample cost to fetch vs decode. Timing lives here in the benchmark, not in the
|
||||
library, to keep the dataset itself instrumentation-free."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fetch_s = 0.0
|
||||
self.decode_s = 0.0
|
||||
|
||||
def __iter__(self):
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self._epoch += 1
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
row = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
t1 = time.perf_counter()
|
||||
sample = self._finalize_sample(row)
|
||||
t2 = time.perf_counter()
|
||||
self.fetch_s += t1 - t0
|
||||
self.decode_s += t2 - t1
|
||||
yield sample
|
||||
|
||||
|
||||
def select_node_episodes(
|
||||
meta: LeRobotDatasetMetadata, num_partitions: int, index: int, cap: int
|
||||
) -> list[int]:
|
||||
"""This node's episode share, mirroring lerobot_train ``--data_partition=node``: group episodes by
|
||||
shared video files, LPT-balance the groups by frame count, take this node's bin (capped)."""
|
||||
episodes = list(range(meta.total_episodes))
|
||||
from_idx = meta.episodes["dataset_from_index"]
|
||||
to_idx = meta.episodes["dataset_to_index"]
|
||||
lengths = [int(to_idx[ep] - from_idx[ep]) for ep in episodes]
|
||||
if meta.video_keys:
|
||||
file_columns = {
|
||||
key: (meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])
|
||||
for key in meta.video_keys
|
||||
}
|
||||
else:
|
||||
file_columns = {"data": (meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])}
|
||||
episode_file_ids = [
|
||||
[(key, chunks[ep], files[ep]) for key, (chunks, files) in file_columns.items()] for ep in episodes
|
||||
]
|
||||
groups = group_episodes_by_files(episode_file_ids)
|
||||
if len(groups) < num_partitions:
|
||||
groups = [[i] for i in range(len(episodes))]
|
||||
group_lengths = [sum(lengths[i] for i in g) for g in groups]
|
||||
bins = partition_episodes(group_lengths, num_partitions)
|
||||
chosen = sorted(episodes[i] for g in bins[index] for i in groups[g])
|
||||
return chosen[:cap] if cap and len(chosen) > cap else chosen
|
||||
|
||||
|
||||
def build_dataset(scenario: str, args: argparse.Namespace):
|
||||
"""Return (dataset, meta, is_map_style, info) for the scenario; single-frame (no delta windows)."""
|
||||
if scenario.startswith("mmap_local"):
|
||||
if not args.local_root:
|
||||
raise SystemExit("mmap_local needs --local_root pointing at a local LeRobotDataset copy.")
|
||||
meta = LeRobotDatasetMetadata(ROBOCASA_REPO, root=args.local_root)
|
||||
episodes = select_node_episodes(meta, args.num_partitions, args.partition_index, args.max_episodes)
|
||||
dataset = LeRobotDataset(ROBOCASA_REPO, root=args.local_root, episodes=episodes, tolerance_s=1e-3)
|
||||
return dataset, meta, True, {"loaded_episodes": len(episodes)}
|
||||
|
||||
data_files_root = MOLMO_BUCKET if scenario.startswith("stream_bucket") else None
|
||||
meta = LeRobotDatasetMetadata(MOLMO_REPO, revision=MOLMO_REVISION)
|
||||
dataset = _TimedStreaming(
|
||||
MOLMO_REPO,
|
||||
revision=MOLMO_REVISION,
|
||||
data_files_root=data_files_root,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
|
||||
|
||||
|
||||
def _split(fetch_s: float, decode_s: float, getitem_s: float, n_probe: int) -> dict:
|
||||
stage = fetch_s + decode_s
|
||||
return {
|
||||
"single_proc_samples_per_s": round(n_probe / getitem_s, 2) if getitem_s else None,
|
||||
"fetch_pct": round(100 * fetch_s / stage, 1) if stage else None,
|
||||
"decode_pct": round(100 * decode_s / stage, 1) if stage else None,
|
||||
}
|
||||
|
||||
|
||||
def measure_fetch_decode_stream(dataset: _TimedStreaming, n_probe: int, warmup: int) -> dict:
|
||||
"""Single-process pass attributing per-sample time to fetch (parquet/network row) vs decode (video)."""
|
||||
it = iter(dataset)
|
||||
for _ in range(warmup): # exclude the cold shuffle-buffer fill from the ratio
|
||||
next(it)
|
||||
dataset.fetch_s = dataset.decode_s = 0.0
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(n_probe):
|
||||
next(it)
|
||||
return _split(dataset.fetch_s, dataset.decode_s, time.perf_counter() - t0, n_probe)
|
||||
|
||||
|
||||
def measure_fetch_decode_map(dataset: LeRobotDataset, n_probe: int, warmup: int) -> dict:
|
||||
"""Same split for the map-style loader: fetch = raw tabular row (``get_raw_item``), decode = the rest
|
||||
of ``__getitem__`` (video decode + transforms). Local reads make fetch tiny and decode dominant.
|
||||
|
||||
Random frames are resampled past any that torchcodec fails to decode, so a single flaky frame can't
|
||||
abort the whole benchmark (the parallel DataLoader pass draws its own fresh random frames)."""
|
||||
rng = random.Random(0)
|
||||
n = len(dataset)
|
||||
fetch_s = getitem_s = 0.0
|
||||
warmed = measured = skipped = attempts = 0
|
||||
while measured < n_probe and attempts < (warmup + n_probe) * 10:
|
||||
attempts += 1
|
||||
i = rng.randrange(n)
|
||||
try:
|
||||
t0 = time.perf_counter()
|
||||
dataset.get_raw_item(i)
|
||||
t1 = time.perf_counter()
|
||||
dataset[i]
|
||||
t2 = time.perf_counter()
|
||||
except Exception:
|
||||
skipped += 1
|
||||
continue
|
||||
if warmed < warmup:
|
||||
warmed += 1
|
||||
continue
|
||||
fetch_s += t1 - t0
|
||||
getitem_s += t2 - t1
|
||||
measured += 1
|
||||
if skipped:
|
||||
print(f"map fetch/decode probe skipped {skipped} undecodable frame(s)", flush=True)
|
||||
return _split(fetch_s, max(0.0, getitem_s - fetch_s), getitem_s, measured)
|
||||
|
||||
|
||||
def run_scenario(scenario: str, args: argparse.Namespace) -> None:
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
device = torch.device(args.device)
|
||||
|
||||
dataset, meta, is_map_style, info = build_dataset(scenario, args)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
shuffle=is_map_style, # map-style: global random shuffle; streaming: shuffled inside the dataset
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=True,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
persistent_workers=args.num_workers > 0,
|
||||
)
|
||||
|
||||
sample_latencies_ms: list[float] = []
|
||||
episodes_per_batch: list[int] = []
|
||||
samples = 0
|
||||
first_batch_latency_s = None
|
||||
steady_start = None
|
||||
|
||||
t_start = time.perf_counter()
|
||||
t_prev = t_start
|
||||
with PeakRSSSampler() as rss:
|
||||
for i, batch in enumerate(loader):
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
if i == args.warmup_batches:
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
samples += args.batch_size
|
||||
ep = batch.get("episode_index")
|
||||
if torch.is_tensor(ep):
|
||||
episodes_per_batch.append(int(torch.unique(ep).numel()))
|
||||
t_prev = now
|
||||
# Measure throughput over a fixed wall-clock window (after warmup) so every scenario is
|
||||
# compared over the same duration regardless of its speed; num_batches is only a safety cap.
|
||||
if steady_start is not None and (now - steady_start) >= args.duration_s:
|
||||
break
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
|
||||
|
||||
now = time.perf_counter()
|
||||
elapsed = now - t_start
|
||||
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
|
||||
|
||||
if samples == 0:
|
||||
raise SystemExit(
|
||||
f"FAILED: 0 samples in {args.duration_s}s for scenario={scenario} "
|
||||
"(inspect worker logs; try --num_workers 0 to surface the exception)."
|
||||
)
|
||||
|
||||
# Single-process fetch/decode split + single-proc throughput. Run AFTER the DataLoader pass: this
|
||||
# decodes video in the main process, which must stay decode-clean until the workers have forked
|
||||
# (decoding before fork corrupts the workers' torchcodec state).
|
||||
del loader
|
||||
if is_map_style:
|
||||
fetch_decode = measure_fetch_decode_map(dataset, args.probe_samples, args.probe_warmup)
|
||||
else:
|
||||
fetch_decode = measure_fetch_decode_stream(dataset, args.probe_samples, args.probe_warmup)
|
||||
|
||||
image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
|
||||
num_cameras = len(meta.video_keys)
|
||||
results = {
|
||||
"scenario": scenario,
|
||||
"rank": rank,
|
||||
"world_size": world_size,
|
||||
"loader": "map_style" if is_map_style else "streaming",
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"episode_pool_size": None if is_map_style else args.episode_pool_size,
|
||||
**info,
|
||||
"num_cameras": num_cameras,
|
||||
"image_shape": image_shape,
|
||||
"fps": meta.fps,
|
||||
"peak_rss_gb": peak_rss_gb,
|
||||
"samples_measured": samples,
|
||||
"steady_window_s": round(steady_elapsed_s, 2),
|
||||
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 3),
|
||||
# Parallel throughput over the steady window (excludes warmup + the prefetch queue it filled).
|
||||
# A sample is one timestep (one dataset item); it decodes num_cameras video frames.
|
||||
"samples_per_s": round(samples / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"decoded_frames_per_s": round(samples / steady_elapsed_s * num_cameras, 2)
|
||||
if steady_elapsed_s
|
||||
else 0.0,
|
||||
**fetch_decode,
|
||||
# Distinct episodes per batch / batch size: ~1.0 ≈ map-style uniform, low ≈ correlated samples.
|
||||
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
|
||||
if episodes_per_batch
|
||||
else None,
|
||||
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
|
||||
if sample_latencies_ms
|
||||
else None,
|
||||
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
|
||||
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
|
||||
"total_time_s": round(elapsed, 2),
|
||||
}
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"{scenario}_bs{args.batch_size}_w{args.num_workers}_r{rank}of{world_size}"
|
||||
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
|
||||
flat = {k: (json.dumps(v) if isinstance(v, (dict, list)) else v) for k, v in results.items()}
|
||||
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=list(flat))
|
||||
writer.writeheader()
|
||||
writer.writerow(flat)
|
||||
print(json.dumps(results, indent=2), flush=True)
|
||||
print(f"Wrote {out_dir / tag}.json and .csv", flush=True)
|
||||
|
||||
|
||||
def submit_chain(args: argparse.Namespace) -> None:
|
||||
"""Submit every scenario as a serial sbatch chain (one network-bound job at a time).
|
||||
|
||||
Bodies are passed to ``sbatch --wrap`` as a single argv (no outer shell), so ``$SLURM_PROCID`` /
|
||||
``$SLURM_NTASKS`` stay literal and expand at job runtime, not at submit time.
|
||||
"""
|
||||
this_file = Path(__file__).resolve()
|
||||
repo_dir = str(this_file.parents[2]) # <repo>/examples/scaling/<this file>
|
||||
logs = Path(repo_dir) / "logs"
|
||||
logs.mkdir(exist_ok=True)
|
||||
run = f"conda run --no-capture-output -n {args.conda_env} python"
|
||||
common = (
|
||||
f"--batch_size {args.batch_size} "
|
||||
f"--prefetch_factor {args.prefetch_factor} --episode_pool_size {args.episode_pool_size} "
|
||||
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
|
||||
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
|
||||
)
|
||||
if args.local_root:
|
||||
common += f" --local_root {args.local_root}"
|
||||
env_prefix = "export TOKENIZERS_PARALLELISM=false"
|
||||
sched = []
|
||||
for opt, env in (("--account", "ACCOUNT"), ("--partition", "PARTITION"), ("--qos", "QOS")):
|
||||
if os.environ.get(env):
|
||||
sched.append(f"{opt}={os.environ[env]}")
|
||||
|
||||
selected = args.scenarios.split(",") if args.scenarios else list(SCENARIOS)
|
||||
prev = ""
|
||||
for scenario in selected:
|
||||
cfg = SCENARIOS[scenario]
|
||||
nw = cfg.get("num_workers", args.num_workers)
|
||||
cpus = cfg.get("cpus", nw + 4)
|
||||
worker = f"{run} {this_file} --scenario {scenario} --num_workers {nw} {common}"
|
||||
if cfg["nodes"] > 1:
|
||||
# One task per node; each exports RANK/WORLD_SIZE so the stream splits shards per node.
|
||||
inner = f"export RANK=$SLURM_PROCID WORLD_SIZE=$SLURM_NTASKS && cd {repo_dir} && {env_prefix} && {worker}"
|
||||
body = f"srun --export=ALL bash -c '{inner}'"
|
||||
node_flags = [f"--nodes={cfg['nodes']}", "--ntasks-per-node=1", "--gpus-per-node=1"]
|
||||
else:
|
||||
body = f"cd {repo_dir} && {env_prefix} && {worker}"
|
||||
node_flags = ["--nodes=1", "--ntasks=1", "--gpus=1"]
|
||||
cmd = [
|
||||
"sbatch",
|
||||
"--parsable",
|
||||
f"--job-name=dlbench_{scenario}",
|
||||
*node_flags,
|
||||
f"--cpus-per-task={cpus}",
|
||||
f"--mem={cfg['mem']}",
|
||||
f"--time={cfg['time']}",
|
||||
f"--output={logs}/%x-%j.out",
|
||||
*sched,
|
||||
]
|
||||
if prev:
|
||||
cmd.append(f"--dependency=afterany:{prev}")
|
||||
cmd += ["--wrap", body]
|
||||
jid = subprocess.check_output(cmd, text=True).strip().split(";")[0]
|
||||
print(f"submitted {jid} dlbench_{scenario}{f' (after {prev})' if prev else ''}", flush=True)
|
||||
prev = jid
|
||||
|
||||
print(f"\nSubmitted {len(selected)} jobs as a serial chain. Results: {args.out_dir}/*.json", flush=True)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument(
|
||||
"--scenario",
|
||||
choices=list(SCENARIOS),
|
||||
default=None,
|
||||
help="Run ONE scenario (worker mode). Omit to submit the whole chain (orchestrator mode).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--scenarios",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Orchestrator only: comma-separated subset of scenarios to submit (default: all).",
|
||||
)
|
||||
p.add_argument("--local_root", type=str, default=None, help="Local LeRobotDataset copy for mmap_local.")
|
||||
p.add_argument(
|
||||
"--num_partitions", type=int, default=8, help="Node count for mmap_local episode partition."
|
||||
)
|
||||
p.add_argument("--partition_index", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_episodes", type=int, default=512, help="Cap mmap_local episodes to the local share."
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=64)
|
||||
p.add_argument("--num_workers", type=int, default=8)
|
||||
p.add_argument("--prefetch_factor", type=int, default=2)
|
||||
p.add_argument(
|
||||
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--duration_s", type=float, default=60.0, help="Steady-state measurement window (seconds)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_batches", type=int, default=1_000_000, help="Safety cap; duration_s governs the window."
|
||||
)
|
||||
p.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state throughput.")
|
||||
p.add_argument(
|
||||
"--probe_samples", type=int, default=100, help="Single-process samples for fetch/decode split."
|
||||
)
|
||||
p.add_argument(
|
||||
"--probe_warmup", type=int, default=10, help="Samples skipped before the fetch/decode probe."
|
||||
)
|
||||
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
p.add_argument("--conda_env", type=str, default="lerobot", help="Conda env the chained jobs run in.")
|
||||
p.add_argument("--out_dir", type=str, default="benchmarks/streaming/results_dataloading")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.scenario is None:
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
"NOTE: no --scenario given, submitting the SLURM chain. This benchmark is meant to run on a "
|
||||
"compute cluster; run from a login node with ACCOUNT/PARTITION/QOS set.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
submit_chain(args)
|
||||
else:
|
||||
run_scenario(args.scenario, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+8
-4
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"datasets>=4.7.0,<6.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -231,9 +231,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -333,6 +333,10 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
|
||||
# re-creation, fixed in datasets#8259 (merged, not yet released). Pin to the merge commit until the
|
||||
# next datasets release ships it, then drop this and bump the floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -39,6 +39,10 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
|
||||
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
|
||||
# the pool holds tabular rows only. Ignored when streaming is False.
|
||||
streaming_episode_pool_size: int = 1024
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
writer.write_table(table)
|
||||
# Emit several row groups with a page index instead of one giant row group. A single row group forces
|
||||
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
|
||||
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
|
||||
# groups bounds per-shard memory while keeping groups large enough to scan
|
||||
# efficiently; the page index lets readers skip to the pages they need.
|
||||
target_row_group_bytes = 32 * 1024 * 1024
|
||||
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
|
||||
writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
|
||||
)
|
||||
writer.write_table(table, row_group_size=row_group_size)
|
||||
writer.close()
|
||||
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
@@ -13,16 +13,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from .feature_utils import get_delta_indices
|
||||
@@ -31,207 +32,56 @@ from .utils import (
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
is_float_in_list,
|
||||
safe_shard,
|
||||
)
|
||||
from .video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look back in the history of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LookAheadError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look ahead in the future of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable[T]:
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
This is useful for streaming datasets where you need to access previous and future items
|
||||
but can't load the entire dataset into memory.
|
||||
|
||||
Example:
|
||||
-------
|
||||
```python
|
||||
ds = load_dataset("c4", "en", streaming=True, split="train")
|
||||
rev = Backtrackable(ds, history=3, lookahead=2)
|
||||
|
||||
x0 = next(rev) # forward
|
||||
x1 = next(rev)
|
||||
x2 = next(rev)
|
||||
|
||||
# Look ahead
|
||||
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
||||
x4_peek = rev.peek_ahead(2) # two items ahead
|
||||
|
||||
# Look back
|
||||
x1_again = rev.peek_back(1) # previous item without moving cursor
|
||||
x0_again = rev.peek_back(2) # two items back
|
||||
|
||||
# Move backward
|
||||
x1_back = rev.prev() # back one step
|
||||
next(rev) # returns x2, continues forward from where we were
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
||||
|
||||
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
||||
if history < 1:
|
||||
raise ValueError("history must be >= 1")
|
||||
if lookahead <= 0:
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
def __iter__(self) -> "Backtrackable[T]":
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
# If we've stepped back, consume from back buffer first
|
||||
if self._cursor < 0: # -1 means "last item", etc.
|
||||
self._cursor += 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
# If we have items in the ahead buffer, use them first
|
||||
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
||||
|
||||
# Add current item to back buffer and reset cursor
|
||||
self._back_buf.append(item)
|
||||
self._cursor = 0
|
||||
return item
|
||||
|
||||
def prev(self) -> T:
|
||||
"""
|
||||
Step one item back in history and return it.
|
||||
Raises IndexError if already at the oldest buffered item.
|
||||
"""
|
||||
if len(self._back_buf) + self._cursor <= 1:
|
||||
raise LookBackError("At start of history")
|
||||
|
||||
self._cursor -= 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
def peek_back(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items back (n=1 == previous item) without moving the cursor.
|
||||
"""
|
||||
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
||||
raise LookBackError("peek_back distance out of range")
|
||||
|
||||
return self._back_buf[self._cursor - (n + 1)]
|
||||
|
||||
def peek_ahead(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
||||
Fills the ahead buffer if necessary.
|
||||
"""
|
||||
if n < 1:
|
||||
raise LookAheadError("peek_ahead distance must be 1 or more")
|
||||
elif n > self._lookahead:
|
||||
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
||||
|
||||
# Fill ahead buffer if we don't have enough items
|
||||
while len(self._ahead_buf) < n:
|
||||
try:
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
|
||||
except StopIteration as err:
|
||||
raise LookAheadError("peek_ahead: not enough items in source") from err
|
||||
|
||||
return self._ahead_buf[n - 1]
|
||||
|
||||
def history(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the buffered history (most recent last).
|
||||
The list length ≤ `history` argument passed at construction.
|
||||
"""
|
||||
if self._cursor == 0:
|
||||
return list(self._back_buf)
|
||||
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
"""
|
||||
return steps <= len(self._back_buf) + self._cursor
|
||||
|
||||
def can_peek_ahead(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can peek ahead `steps` items.
|
||||
This may involve trying to fill the ahead buffer.
|
||||
"""
|
||||
if self._lookahead > 0 and steps > self._lookahead:
|
||||
return False
|
||||
|
||||
# Try to fill ahead buffer to check if we can peek that far
|
||||
try:
|
||||
while len(self._ahead_buf) < steps:
|
||||
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
||||
return False
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
return True
|
||||
except StopIteration:
|
||||
return False
|
||||
# Bound the default frame-level shuffle buffer: rows are tabular-only (~KB each), so this is
|
||||
# roughly a few hundred MB of host RAM per consumer at the cap.
|
||||
_MAX_DEFAULT_FRAME_BUFFER = 200_000
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
"""LeRobotDataset with streaming capabilities, built on native HF `datasets` primitives.
|
||||
|
||||
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
||||
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
||||
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
||||
The tabular side is a pure `datasets` pipeline::
|
||||
|
||||
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
||||
items, allowing us to access previous frames for delta timestamps without loading the entire
|
||||
dataset into memory.
|
||||
load_dataset(streaming=True) # parquet shards from the Hub / a bucket
|
||||
-> split_dataset_by_node(rank, world_size) # disjoint shards per rank
|
||||
-> batch(by_column="episode_index") # whole episodes
|
||||
-> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob)
|
||||
-> map(explode + exact delta windows) # episode -> frames, windows are exact
|
||||
-> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave
|
||||
|
||||
and this class is a thin torch ``IterableDataset`` wrapper around it that decodes video
|
||||
per emitted sample (decode-on-exit), applies image transforms, and attaches the task
|
||||
string. DataLoader workers are split natively by `datasets` (disjoint shards per worker),
|
||||
and resume uses the native ``state_dict`` / ``load_state_dict``.
|
||||
|
||||
Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are
|
||||
exact slices of the resident episode with correct padding at episode boundaries.
|
||||
|
||||
Resume: ``state_dict()`` / ``load_state_dict()`` delegate to `datasets`. Samples sitting in
|
||||
the shuffle buffers at checkpoint time are skipped on resume (documented `datasets`
|
||||
behavior), so resume never repeats data but may drop up to roughly
|
||||
``episode_pool_size x episode_len + frame_shuffle_buffer_size`` frames — negligible at
|
||||
training scale. The contract is exact with ``num_workers=0``; with DataLoader workers use
|
||||
``torchdata.stateful_dataloader.StatefulDataLoader``, which checkpoints each worker's
|
||||
dataset state through this same protocol.
|
||||
|
||||
Example:
|
||||
Basic usage:
|
||||
```python
|
||||
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
# Create a streaming dataset with delta timestamps
|
||||
delta_timestamps = {
|
||||
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
||||
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
||||
}
|
||||
|
||||
dataset = StreamingLeRobotDataset(
|
||||
repo_id="your-dataset-repo-id",
|
||||
delta_timestamps=delta_timestamps,
|
||||
streaming=True,
|
||||
buffer_size=1000,
|
||||
delta_timestamps={"action": [0.0, 0.1, 0.2]},
|
||||
episode_pool_size=1024,
|
||||
)
|
||||
|
||||
# Iterate over the dataset
|
||||
for i, item in enumerate(dataset):
|
||||
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
||||
# item will contain stacked frames according to delta_timestamps
|
||||
if i >= 10:
|
||||
break
|
||||
for sample in dataset:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -246,12 +96,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
buffer_size: int = 1000,
|
||||
max_num_shards: int = 16,
|
||||
episode_pool_size: int | None = 1024,
|
||||
frame_shuffle_buffer_size: int | None = None,
|
||||
buffer_size: int | None = None,
|
||||
max_num_shards: int | None = None,
|
||||
seed: int = 42,
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
rank: int | None = None,
|
||||
world_size: int | None = None,
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -267,11 +123,30 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
||||
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
||||
episode_pool_size (int, optional): Whole episodes each consumer keeps open to shuffle
|
||||
across — the randomness knob. Larger mixes more episodes per batch (closer to
|
||||
map-style uniform) at the cost of cold-start latency and frame-buffer RAM.
|
||||
Defaults to 1024.
|
||||
frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the
|
||||
episode pool. Defaults to ``episode_pool_size x average episode length`` (capped),
|
||||
which matches the pool's mixing radius.
|
||||
buffer_size (int | None, optional): Deprecated; superseded by ``episode_pool_size``.
|
||||
max_num_shards (int | None, optional): Deprecated; `datasets` handles shard-to-worker
|
||||
assignment natively.
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
rng (np.random.Generator | None, optional): Deprecated; ignored.
|
||||
shuffle (bool, optional): Whether to shuffle. False yields episodes in stream order.
|
||||
rank (int | None, optional): This process' rank for distributed training. Each rank streams
|
||||
a disjoint set of shards via ``split_dataset_by_node``. When omitted, resolved from
|
||||
Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
|
||||
world_size (int | None, optional): Total number of distributed processes. When omitted,
|
||||
resolved from Accelerate or ``WORLD_SIZE``, defaulting to 1. For an even per-rank split,
|
||||
``num_shards % world_size == 0`` should hold (warned otherwise).
|
||||
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
|
||||
When omitted, sized to the episode pool's working set, capped at 128.
|
||||
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
|
||||
trees (e.g. ``hf://buckets/<owner>/<name>``). When set, parquet and video bytes are read
|
||||
from there while metadata still loads from ``repo_id`` on the Hub.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -284,15 +159,32 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.seed = seed
|
||||
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
||||
if rng is not None:
|
||||
logger.warning("StreamingLeRobotDataset: `rng` is deprecated and ignored; use `seed`.")
|
||||
if buffer_size is not None:
|
||||
logger.warning(
|
||||
"StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; "
|
||||
"use `episode_pool_size` (whole episodes, not frames)."
|
||||
)
|
||||
if max_num_shards is not None:
|
||||
logger.warning(
|
||||
"StreamingLeRobotDataset: `max_num_shards` is deprecated and ignored; "
|
||||
"`datasets` assigns shards to DataLoader workers natively."
|
||||
)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 1024
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
self._epoch = 0
|
||||
self._in_flight_epoch = 0
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -314,15 +206,42 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
if self.data_files_root is not None:
|
||||
# Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub.
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
"parquet",
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files=f"{self.data_files_root}/data/*/*.parquet",
|
||||
)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Drop any parquet columns not declared in the dataset's feature contract. Some revisions / sources
|
||||
# (e.g. an unversioned bucket holding `main`) carry extra, possibly variable-length annotation
|
||||
# columns such as `language_events`; left in, they leak into the sample and break default DataLoader
|
||||
# collation across frames of differing length. On a clean revision this is a no-op.
|
||||
known_columns = set(self.meta.features)
|
||||
extra_columns = [c for c in (self.hf_dataset.column_names or []) if c not in known_columns]
|
||||
if extra_columns:
|
||||
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
|
||||
|
||||
self.num_shards = self.hf_dataset.num_shards
|
||||
|
||||
avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes)))
|
||||
self.frame_shuffle_buffer_size = (
|
||||
frame_shuffle_buffer_size
|
||||
if frame_shuffle_buffer_size is not None
|
||||
else min(self.episode_pool_size * avg_episode_len, _MAX_DEFAULT_FRAME_BUFFER)
|
||||
)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
self._pipeline = self._build_pipeline()
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
@@ -337,96 +256,185 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
return self.meta.fps
|
||||
|
||||
@staticmethod
|
||||
def _iter_random_indices(
|
||||
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
||||
) -> Iterator[int]:
|
||||
while True:
|
||||
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
||||
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
|
||||
"""Resolve (rank, world_size) for distributed streaming.
|
||||
|
||||
@staticmethod
|
||||
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
|
||||
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
|
||||
"""
|
||||
import os
|
||||
|
||||
if rank is not None and world_size is not None:
|
||||
return rank, world_size
|
||||
|
||||
try:
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if PartialState._shared_state: # only read it if already initialized; never initialize here
|
||||
state = PartialState()
|
||||
return state.process_index, state.num_processes
|
||||
except Exception:
|
||||
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
|
||||
|
||||
env_rank = os.environ.get("RANK")
|
||||
env_world = os.environ.get("WORLD_SIZE")
|
||||
if env_rank is not None and env_world is not None:
|
||||
return int(env_rank), int(env_world)
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _build_pipeline(self) -> datasets.IterableDataset:
|
||||
"""Assemble the native tabular pipeline (everything except video decode)."""
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
if ds.num_shards % self.world_size != 0:
|
||||
logger.warning(
|
||||
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
|
||||
"datasets falls back to example-level splitting where every rank reads (and pays "
|
||||
"for) the full stream. Re-shard the dataset or adjust world size."
|
||||
)
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
ds = ds.batch(by_column="episode_index")
|
||||
episode_columns = list(ds.column_names or self.hf_dataset.column_names or [])
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size)
|
||||
# A row-count-changing batched map must drop the input columns explicitly; the exploded
|
||||
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
|
||||
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(seed=self.seed + 1, buffer_size=max(2, self.frame_shuffle_buffer_size))
|
||||
return ds
|
||||
|
||||
def _tabular_window_keys(self) -> list[str]:
|
||||
if self.delta_indices is None:
|
||||
return []
|
||||
return [key for key in self.delta_indices if key not in self.meta.video_keys]
|
||||
|
||||
def _explode_episodes(self, episode_batch: dict[str, list[list]]) -> dict[str, list]:
|
||||
"""Episode batches -> per-frame rows, with exact tabular delta windows and pad masks.
|
||||
|
||||
Runs inside the `datasets` pipeline (plain Python values, no torch). For each windowed key
|
||||
the original per-frame value is replaced by its delta window (list of values, clamped to
|
||||
the episode bounds) plus a ``{key}_is_pad`` mask, mirroring the map-style dataset.
|
||||
"""
|
||||
window_keys = set(self._tabular_window_keys())
|
||||
out: dict[str, list] = {key: [] for key in episode_batch if key not in window_keys}
|
||||
for key in window_keys:
|
||||
out[key] = []
|
||||
out[f"{key}_is_pad"] = []
|
||||
|
||||
num_episodes = len(episode_batch["episode_index"])
|
||||
for e in range(num_episodes):
|
||||
length = len(episode_batch["episode_index"][e])
|
||||
for key, column in episode_batch.items():
|
||||
if key in window_keys:
|
||||
continue
|
||||
out[key].extend(column[e])
|
||||
for key in window_keys:
|
||||
episode_column = episode_batch[key][e]
|
||||
deltas = self.delta_indices[key]
|
||||
for t in range(length):
|
||||
window = []
|
||||
is_pad = []
|
||||
for delta in deltas:
|
||||
j = t + delta
|
||||
window.append(episode_column[min(max(j, 0), length - 1)])
|
||||
is_pad.append(not 0 <= j < length)
|
||||
out[key].append(window)
|
||||
out[f"{key}_is_pad"].append(is_pad)
|
||||
return out
|
||||
|
||||
def _make_video_decoder_cache(self) -> VideoDecoderCache:
|
||||
"""Size the decoder cache to the pool's working set (pool episodes x cameras), capped at 128."""
|
||||
if self.video_decoder_cache_size is not None:
|
||||
return VideoDecoderCache(max_size=self.video_decoder_cache_size)
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if num_cameras == 0:
|
||||
return VideoDecoderCache()
|
||||
return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128))
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
# the logic is to add 2 levels of randomness:
|
||||
# (1) sample one shard at random from the ones available, and
|
||||
# (2) sample one frame from the shard sampled at (1)
|
||||
frames_buffer = []
|
||||
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
||||
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
|
||||
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
|
||||
# epoch is tracked separately so a mid-iteration state_dict() records the epoch the
|
||||
# stream position actually belongs to.
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self._epoch += 1
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
try:
|
||||
for frame in self.make_frame(backtrack_dataset):
|
||||
if len(frames_buffer) == self.buffer_size:
|
||||
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||
yield frames_buffer[i]
|
||||
frames_buffer[i] = frame
|
||||
else:
|
||||
frames_buffer.append(frame)
|
||||
break # random shard sampled, switch shard
|
||||
except (
|
||||
RuntimeError,
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
row = next(iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
yield self._finalize_sample(row)
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
def _finalize_sample(self, row: dict) -> dict:
|
||||
"""Torch conversion + video decode (decode-on-exit) + transforms + task for one frame."""
|
||||
window_keys = self._tabular_window_keys()
|
||||
pad_masks = {f"{key}_is_pad": torch.BoolTensor(row.pop(f"{key}_is_pad")) for key in window_keys}
|
||||
item = item_to_torch(row)
|
||||
item.update(pad_masks)
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
if delta_timestamps is None:
|
||||
return 1, 1
|
||||
if len(self.meta.video_keys) > 0:
|
||||
ep_idx = int(item["episode_index"])
|
||||
current_ts = float(item["timestamp"])
|
||||
# Per-camera episode-local bounds [0, duration]: out-of-episode deltas pad instead of
|
||||
# decoding against a neighbouring episode sharing the same video file.
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
0.0,
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
|
||||
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if not dynamic_bounds:
|
||||
# Fix the windows
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.meta.camera_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
item.update(video_frames)
|
||||
if self.delta_indices is not None:
|
||||
item.update(
|
||||
self._get_video_frame_padding_mask(video_frames, query_timestamps, original_timestamps)
|
||||
)
|
||||
|
||||
return lookback, lookahead
|
||||
item["task"] = self.meta.tasks.iloc[int(item["task_index"])].name
|
||||
return item
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
"""Set the epoch the next ``__iter__`` will use (reshuffles the native pipeline)."""
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Native `datasets` stream state. Exact contract with ``num_workers=0``; with DataLoader
|
||||
workers use ``torchdata.stateful_dataloader.StatefulDataLoader`` (it checkpoints each
|
||||
worker's copy through this protocol). Samples in the shuffle buffers are skipped on
|
||||
resume (never repeated), bounded by the pool + frame buffer sizes.
|
||||
"""
|
||||
return {"pipeline": self._pipeline.state_dict(), "epoch": self._in_flight_epoch}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
# Resume continues inside the recorded epoch: the next __iter__ replays that epoch's
|
||||
# shuffle order from the restored stream position, then advances normally.
|
||||
self._epoch = int(state_dict.get("epoch", 0))
|
||||
self._pipeline.load_state_dict(state_dict["pipeline"])
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
) -> dict[str, list[float]]:
|
||||
if indices is not None:
|
||||
return {
|
||||
key: (
|
||||
start_ts + torch.tensor(indices[key]) / self.fps
|
||||
).tolist() # NOTE: why not delta_timestamps directly?
|
||||
key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist()
|
||||
for key in self.delta_timestamps
|
||||
}
|
||||
else:
|
||||
@@ -463,65 +471,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
|
||||
# Apply delta querying logic if necessary
|
||||
if self.delta_indices is not None:
|
||||
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
# Load video frames, when needed
|
||||
if len(self.meta.video_keys) > 0:
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
|
||||
# Some timestamps might not result available considering the episode's boundaries
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
updates.append(padding_mask)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
|
||||
yield result
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
@@ -552,11 +501,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
|
||||
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
rel_path = str(self.meta.get_video_file_path(ep_idx, video_key))
|
||||
if self.data_files_root is not None:
|
||||
root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
root = self.meta.url_root
|
||||
else:
|
||||
root = self.root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
query_ts,
|
||||
shifted_query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
@@ -566,116 +524,6 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return item
|
||||
|
||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||
# TODO(fracapuano): Modularize this function, refactor the code
|
||||
"""Get frames with delta offsets using the backtrackable iterator.
|
||||
|
||||
Args:
|
||||
current_item (dict): Current item from the iterator.
|
||||
ep_idx (int): Episode index.
|
||||
|
||||
Returns:
|
||||
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||
"""
|
||||
current_episode_idx = current_item["episode_index"]
|
||||
|
||||
# Prepare results
|
||||
query_result = {}
|
||||
padding = {}
|
||||
|
||||
for key, delta_indices in self.delta_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue # visual frames are decoded separately
|
||||
|
||||
target_frames = []
|
||||
is_pad = []
|
||||
|
||||
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||
delta_results = {}
|
||||
|
||||
# Separate and sort deltas by difficulty (easier operations first)
|
||||
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||
zero_deltas = [d for d in delta_indices if d == 0]
|
||||
|
||||
# Process zero deltas (current frame)
|
||||
for delta in zero_deltas:
|
||||
delta_results[delta] = (
|
||||
current_item[key],
|
||||
False,
|
||||
)
|
||||
|
||||
# Process negative deltas in order of increasing difficulty
|
||||
lookback_failed = False
|
||||
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in negative_deltas:
|
||||
if lookback_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (past_item[key], False)
|
||||
last_successful_frame = past_item[key]
|
||||
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookback_failed = True # All subsequent negative deltas will also fail
|
||||
|
||||
# Process positive deltas in order of increasing difficulty
|
||||
lookahead_failed = False
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in positive_deltas:
|
||||
if lookahead_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (future_item[key], False)
|
||||
last_successful_frame = future_item[key]
|
||||
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||
|
||||
# Reconstruct original order for stacking
|
||||
for delta in delta_indices:
|
||||
frame, is_padded = delta_results[delta]
|
||||
|
||||
# add batch dimension for stacking
|
||||
target_frames.append(frame) # frame.unsqueeze(0))
|
||||
is_pad.append(is_padded)
|
||||
|
||||
# Stack frames and add to results
|
||||
if target_frames:
|
||||
query_result[key] = torch.stack(target_frames)
|
||||
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||
|
||||
return query_result, padding
|
||||
|
||||
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
||||
"""
|
||||
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
||||
|
||||
@@ -273,7 +273,11 @@ class VideoDecoderCache:
|
||||
self._cache.move_to_end(video_path)
|
||||
return entry[0]
|
||||
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
|
||||
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
|
||||
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
|
||||
# mostly-sequential reads torchcodec issues.
|
||||
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
except Exception:
|
||||
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -23,6 +23,7 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -49,6 +50,7 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -66,6 +68,8 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
|
||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,10 +56,14 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
events: dict,
|
||||
features: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
dataset,
|
||||
single_task: str,
|
||||
) -> None:
|
||||
"""Policy-driven recording loop for a single episode."""
|
||||
interpolator = self._interpolator
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
|
||||
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
if sleep_t < 0:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||
"3) CPU starvation"
|
||||
)
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def _reset_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
display_data: bool,
|
||||
display_compressed: bool,
|
||||
) -> None:
|
||||
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||
processors = ctx.processors
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
if teleop is not None:
|
||||
act = teleop.get_action()
|
||||
act_teleop = processors.teleop_action_processor((act, obs))
|
||||
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||
robot.send_action(robot_action)
|
||||
|
||||
if display_data:
|
||||
obs_processed = processors.robot_observation_processor(obs)
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=act_teleop,
|
||||
compress_images=display_compressed,
|
||||
)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||
cfg = ctx.runtime.cfg
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
|
||||
if (
|
||||
cfg.dataset is not None
|
||||
and cfg.dataset.push_to_hub
|
||||
and ctx.data.dataset is not None
|
||||
and safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=cfg.dataset.tags,
|
||||
private=cfg.dataset.private,
|
||||
)
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=cfg.return_to_initial_position,
|
||||
)
|
||||
log_say("Exiting", play_sounds)
|
||||
logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
@@ -111,6 +112,18 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
|
||||
@@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -384,14 +387,21 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
if hasattr(active_cfg, "drop_n_last_frames") and not cfg.dataset.streaming:
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
@@ -416,9 +426,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
# Prepare everything with accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming IterableDataset is already rank-disjoint via split_dataset_by_node, so we must
|
||||
# NOT hand the dataloader to accelerate: its IterableDatasetShard would keep only every
|
||||
# world_size-th batch of each rank's already-disjoint stream (silently training on 1/N of the
|
||||
# data while decoding all of it). Batches are moved to the device manually in the loop below.
|
||||
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
|
||||
else:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -458,6 +475,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
if cfg.dataset.streaming:
|
||||
# The streaming dataloader is not accelerate-prepared (see above), so move to device here.
|
||||
batch = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
|
||||
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
|
||||
@@ -114,6 +114,30 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == order_before
|
||||
|
||||
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
|
||||
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
|
||||
rng = np.random.default_rng(streaming_ds.seed)
|
||||
buffer_size = streaming_ds.buffer_size
|
||||
num_shards = streaming_ds.num_shards
|
||||
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
|
||||
|
||||
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
|
||||
|
||||
frames_buffer = []
|
||||
expected_indices = []
|
||||
|
||||
while shard_iterators: # While there are still available shards
|
||||
available_shard_keys = list(shard_iterators.keys())
|
||||
if not available_shard_keys:
|
||||
break
|
||||
|
||||
# Call _infinite_generator_over_elements with current available shards (key difference!)
|
||||
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
|
||||
|
||||
try:
|
||||
frame_index = next(shard_iterators[shard_key])
|
||||
|
||||
if len(frames_buffer) == buffer_size:
|
||||
i = next(buffer_indices_generator)
|
||||
expected_indices.append(frames_buffer[i])
|
||||
frames_buffer[i] = frame_index
|
||||
else:
|
||||
frames_buffer.append(frame_index)
|
||||
|
||||
except StopIteration:
|
||||
del shard_iterators[shard_key] # Remove exhausted shard
|
||||
|
||||
rng.shuffle(frames_buffer)
|
||||
expected_indices.extend(frames_buffer)
|
||||
|
||||
return expected_indices
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
@@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
@@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
|
||||
for epoch_indices in epochs:
|
||||
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
if shuffle:
|
||||
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
|
||||
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
|
||||
else:
|
||||
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
@@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
def make_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
episode_pool_size=3,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
first = [int(frame["index"]) for frame in make_ds()]
|
||||
again = [int(frame["index"]) for frame in make_ds()]
|
||||
|
||||
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
assert first == again, "same seed must reproduce the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -288,6 +218,11 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats/ints where map-style yields
|
||||
# 0-dim tensors (long-standing accepted difference). Compare by value.
|
||||
check = float(left) == float(right)
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
|
||||
|
||||
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
|
||||
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
|
||||
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
|
||||
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
WORKER = """
|
||||
import json, sys
|
||||
from accelerate import PartialState
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
state = PartialState()
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
|
||||
)
|
||||
indices = [int(frame["index"]) for frame in ds]
|
||||
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
|
||||
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
|
||||
json.dump(payload, f)
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
|
||||
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
|
||||
total_frames = 160
|
||||
repo_id = f"{DUMMY_REPO_ID}-acc"
|
||||
root = tmp_path / "ds"
|
||||
lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=8,
|
||||
total_frames=total_frames,
|
||||
use_videos=False,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
)
|
||||
|
||||
worker = tmp_path / "worker.py"
|
||||
worker.write_text(WORKER)
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num_processes=2",
|
||||
"--num_machines=1",
|
||||
"--mixed_precision=no",
|
||||
"--dynamo_backend=no",
|
||||
"--cpu",
|
||||
str(worker),
|
||||
str(root),
|
||||
repo_id,
|
||||
str(out_dir),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
assert result.returncode == 0, (
|
||||
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
|
||||
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
|
||||
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
|
||||
|
||||
rank_sets = [set(p["indices"]) for p in payloads]
|
||||
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
|
||||
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
@@ -0,0 +1,314 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
|
||||
prefetching, deterministic fast-forward resume, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
|
||||
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
|
||||
|
||||
monkeypatch.delenv("RANK", raising=False)
|
||||
monkeypatch.delenv("WORLD_SIZE", raising=False)
|
||||
# No accelerate state, no env -> single process.
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
|
||||
|
||||
monkeypatch.setenv("RANK", "3")
|
||||
monkeypatch.setenv("WORLD_SIZE", "4")
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
|
||||
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
|
||||
episode_frames = 300
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
|
||||
)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
|
||||
if idx + horizon_frames < episode_frames:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
|
||||
repo_id = f"{DUMMY_REPO_ID}-seeds"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
|
||||
|
||||
def order(seed):
|
||||
return _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=seed,
|
||||
episode_pool_size=4,
|
||||
max_num_shards=2,
|
||||
)
|
||||
)
|
||||
|
||||
assert order(0) == order(0), "same seed must reproduce the same order"
|
||||
assert order(0) != order(1), "different seeds should give different orders"
|
||||
|
||||
|
||||
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
|
||||
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-epochs"
|
||||
total_frames = 120
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
epoch_0 = _stream_indices(ds)
|
||||
epoch_1 = _stream_indices(ds)
|
||||
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
|
||||
assert epoch_0 != epoch_1, "epoch did not reshuffle"
|
||||
|
||||
|
||||
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-mix"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
|
||||
)
|
||||
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
|
||||
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
|
||||
|
||||
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
|
||||
"""Native state_dict resume: no sample is re-yielded; loss is bounded by the shuffle buffers."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=7,
|
||||
episode_pool_size=2,
|
||||
frame_shuffle_buffer_size=8,
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
consumed = [int(next(it)["index"]) for _ in range(30)]
|
||||
state = ds.state_dict()
|
||||
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
rest = [int(frame["index"]) for frame in resumed_ds]
|
||||
|
||||
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
|
||||
# in-flight buffer contents are skipped on resume (documented datasets behavior):
|
||||
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
|
||||
covered = len(set(consumed) | set(rest))
|
||||
max_in_flight = 2 * 30 + 8
|
||||
assert covered >= total_frames - max_in_flight
|
||||
assert covered + len(consumed) >= total_frames - max_in_flight
|
||||
|
||||
|
||||
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
|
||||
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
|
||||
import datasets as hf_datasets
|
||||
|
||||
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
|
||||
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
|
||||
assert state is not None
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
assert EpisodicStrategyConfig().type == "episodic"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategy,
|
||||
EpisodicStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
|
||||
@@ -1084,8 +1084,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
version = "5.0.1.dev0"
|
||||
source = { git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3#2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
{ name = "filelock" },
|
||||
@@ -1102,10 +1102,6 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
@@ -1764,7 +1760,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dm-control" },
|
||||
@@ -1772,14 +1768,14 @@ dependencies = [
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "mujoco" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-hil"
|
||||
version = "0.1.13"
|
||||
version = "0.1.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "gymnasium" },
|
||||
@@ -1789,9 +1785,9 @@ dependencies = [
|
||||
{ name = "pygame" },
|
||||
{ name = "pynput" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1881,7 +1877,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
|
||||
|
||||
[[package]]
|
||||
name = "hf-libero"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bddl", marker = "sys_platform == 'linux'" },
|
||||
@@ -1902,7 +1898,10 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'linux'" },
|
||||
{ name = "wandb", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
@@ -3075,7 +3074,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
@@ -3090,12 +3089,12 @@ requires-dist = [
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
|
||||
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
|
||||
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user