feat(streaming): multinode example, dataloading benchmark, distributed smoke test

- examples/scaling/train_streaming_multinode.py: Accelerate-based distributed/
  resumable streaming training (no DistributedSampler; rank/world_size auto-resolved),
  checkpoints the dataset stream state, and supports a --dummy pure-dataloading path
  with throughput logging. SLURM launcher in slurm/train_streaming_robocasa.sh.
- benchmarks/streaming/benchmark_streaming.py: dummy-consumer dataloading benchmark
  (single / sarm frame modes) emitting frames/s/node, p50/p95/p99 sample latency,
  first-batch latency, and VideoDecoderCache reuse stats as JSON + CSV. SLURM launcher
  + README documenting the source/node/mode matrix and manual bucket prewarming.
- VideoDecoderCache: add hit/miss/eviction counters and a stats() method so the
  benchmark can surface decoder thrash (no new cache, no eviction-policy change).
- tests/datasets/test_streaming_distributed.py: accelerate-launch smoke test asserting
  per-rank disjointness; skips (does not false-pass) when <2 processes spawn.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-09 13:48:23 +02:00
parent d1fc8e298c
commit 68fa5d80b0
8 changed files with 608 additions and 4 deletions
@@ -0,0 +1,100 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
"""
import json
import shutil
import subprocess
import sys
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
from tests.fixtures.constants import DUMMY_REPO_ID
# Source of the script each `accelerate launch` process runs. It takes
# `<root> <repo_id> <out_dir>` on the command line, iterates the streaming dataset once with
# shuffling disabled, and writes `<out_dir>/rank_<process_index>.json` holding the process rank,
# the world size, and every frame index that process saw.
# The text is written verbatim to a .py file, so its internal indentation must be valid Python.
WORKER = """
import json, sys
from accelerate import PartialState
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
    repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
    json.dump(payload, f)
"""
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
    """Check that two `accelerate launch` CPU processes stream disjoint, jointly complete frames.

    A small local dataset is built under `tmp_path`, the WORKER script is launched with
    `--num_processes=2`, and each process reports the frame indices it iterated as a JSON file.
    The test is skipped, not passed, when fewer than two reports exist or when any process
    reports a world size below two.
    """
    total_frames = 160
    repo_id = f"{DUMMY_REPO_ID}-acc"
    dataset_root = tmp_path / "ds"
    lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=8,
        total_frames=total_frames,
        use_videos=False,
        data_files_size_in_mb=0.001,
        chunks_size=1,
    )

    # Materialise the worker script and the directory the ranks report into.
    script_path = tmp_path / "worker.py"
    script_path.write_text(WORKER)
    report_dir = tmp_path / "out"
    report_dir.mkdir()

    launch_flags = [
        "--num_processes=2",
        "--num_machines=1",
        "--mixed_precision=no",
        "--dynamo_backend=no",
        "--cpu",
    ]
    script_args = [str(script_path), str(dataset_root), repo_id, str(report_dir)]
    proc = subprocess.run(
        ["accelerate", "launch", *launch_flags, *script_args],
        capture_output=True,
        text=True,
        timeout=600,
    )
    assert proc.returncode == 0, (
        f"accelerate launch failed:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
    )

    # One JSON report per rank, read in sorted filename order.
    reports = []
    for report_file in sorted(report_dir.glob("rank_*.json")):
        reports.append(json.loads(report_file.read_text()))

    # Skip rather than pass when the launch did not produce a real multi-process run.
    spawned_two = len(reports) >= 2 and all(report["world"] >= 2 for report in reports)
    if not spawned_two:
        pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")

    per_rank = [set(report["indices"]) for report in reports]
    assert not (per_rank[0] & per_rank[1]), "ranks streamed overlapping frames under accelerate launch"

    covered = set()
    for seen in per_rank:
        covered |= seen
    assert covered == set(range(total_frames)), "ranks did not jointly cover all frames"
if __name__ == "__main__":
    # Running the file directly executes only this module's tests, verbosely, and
    # forwards pytest's exit status to the calling shell.
    exit_code = pytest.main([__file__, "-v"])
    sys.exit(exit_code)
+9 -4
View File
@@ -115,8 +115,12 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
"""
repo_id = f"{DUMMY_REPO_ID}-sarm"
# Two episodes of 200 frames each -> a +150-frame lookahead stays inside an episode for early frames.
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=2, total_frames=400)
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
episode_frames = 300
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
)
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
delta_timestamps = {ACTION: [0.0, horizon_s]}
@@ -130,11 +134,12 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
)
horizon_frames = int(round(horizon_s * ds.fps))
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
checked = 0
for frame in ds:
idx = int(frame["index"])
# Only assert on frames whose +horizon target is still inside the same episode.
if int(frame["episode_index"]) == 0 and idx + horizon_frames < 200:
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
if idx + horizon_frames < episode_frames:
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
)