mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
feat(streaming): episode-pool iteration with decode-on-exit, video prefetch, and exact resume
Replace the shard/Backtrackable/decoded-shuffle-buffer internals with an
episode pool: each (rank x worker) consumer keeps episode_pool_size whole
episodes' tabular rows in RAM and emits uniformly random frames across
them. delta_timestamps windows become exact in-RAM slices with correct
boundary padding (the Backtrackable machinery and its lookback/lookahead
ceilings are gone), and video is decoded only when a sample is emitted,
so pool memory stays tabular-sized instead of buffer_size decoded
samples.
- Prefetch-on-admit: when streaming from a remote source, each pooled
episode's video files download to a local cache in the background
(refcounted, since v3 packs several episodes per file; deleted on
eviction), so decode-on-exit reads local bytes instead of paying
network seek latency.
- Per-consumer RNG derived from (seed, epoch, rank, worker): consumers
decorrelated, runs reproducible, epochs reshuffle automatically.
- Deterministic fast-forward resume: load_state_dict takes the trainer's
{batches_consumed, batch_size}; each worker re-derives its own skip
from the DataLoader's round-robin batch assignment and replays
tabular-only (no decode). Exact within an epoch, works with
num_workers > 0, and the same state file serves every rank. Replaces
the per-shard HF state_dict approach, which lived in worker processes
and could not be captured from the trainer.
- Shard-cap default removed (max_num_shards=None uses every parquet
shard); runtime warnings are emitted for non-divisible world sizes (where
the `datasets` library degrades to read-everything splitting) and for
workers left without shards.
- episode_pool_size replaces buffer_size (deprecated, ignored with a
warning); decoder cache sized to the pool working set, capped at 128.
Legacy order-replication tests asserted the old buffer algorithm
step-by-step and are rewritten as behavior contracts (exactly-once
coverage, per-seed determinism, epoch reshuffle). Value-level parity
tests against the map-style dataset pass unchanged.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -62,7 +62,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--buffer_size", type=int, default=2000)
|
||||
parser.add_argument("--episode_pool_size", type=int, default=64)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--video_decode_device",
|
||||
@@ -86,7 +86,7 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str
|
||||
root=args.root,
|
||||
data_files_root=args.data_files_root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
video_decode_device=args.video_decode_device,
|
||||
tolerance_s=1e-3,
|
||||
@@ -172,7 +172,7 @@ def main() -> None:
|
||||
"mode": args.mode,
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"buffer_size": args.buffer_size,
|
||||
"episode_pool_size": args.episode_pool_size,
|
||||
"num_cameras": len(meta.video_keys),
|
||||
"fps": meta.fps,
|
||||
"device": str(device),
|
||||
|
||||
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
|
||||
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
||||
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
||||
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
||||
- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
|
||||
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
|
||||
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
||||
|
||||
Launch with Accelerate (single node, N GPUs):
|
||||
@@ -57,7 +57,10 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
|
||||
"--episode_pool_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Whole episodes open per consumer (randomness knob).",
|
||||
)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
|
||||
@@ -78,7 +81,7 @@ def make_dataloader(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
@@ -121,13 +124,13 @@ def main() -> None:
|
||||
# of it). Batches are moved to the device manually in the loop.
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
|
||||
# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
|
||||
# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
|
||||
# checkpoint this script wrote itself.
|
||||
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
|
||||
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
|
||||
# worker re-derives its own skip. Same file works for every rank.
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resumed dataset stream from {args.resume_from}")
|
||||
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
|
||||
|
||||
step = 0
|
||||
frames_seen = 0
|
||||
@@ -157,8 +160,11 @@ def main() -> None:
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the dataset stream position alongside the model so a restart resumes mid-stream.
|
||||
torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
|
||||
# Save the consumed-batch counters so a restart fast-forwards to this position.
|
||||
torch.save(
|
||||
{"batches_consumed": step, "batch_size": args.batch_size},
|
||||
ckpt / "dataset_state.pt",
|
||||
)
|
||||
if model is not None:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ for MODE in single sarm; do
|
||||
--mode $MODE \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--buffer_size 4000 \
|
||||
--episode_pool_size 64 \
|
||||
--num_batches 300 \
|
||||
--out_dir '"$OUT_DIR"'/node${SLURM_NODEID}
|
||||
done
|
||||
|
||||
@@ -83,7 +83,7 @@ for SOURCE in $SOURCES; do
|
||||
$RUN benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id $REPO_ID $ROOTFLAG \
|
||||
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --episode_pool_size $B \
|
||||
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
|
||||
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
|
||||
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
|
||||
|
||||
@@ -42,7 +42,7 @@ accelerate launch \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--buffer_size 4000 \
|
||||
--episode_pool_size 64 \
|
||||
--steps 200000 \
|
||||
--save_freq 2000 \
|
||||
--log_freq 50
|
||||
|
||||
@@ -39,9 +39,10 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
|
||||
# of host RAM. Ignored when streaming is False.
|
||||
streaming_buffer_size: int = 1000
|
||||
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
|
||||
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
|
||||
# the pool holds tabular rows only. Ignored when streaming is False.
|
||||
streaming_episode_pool_size: int = 64
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
buffer_size=cfg.dataset.streaming_buffer_size,
|
||||
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
    """Re-run the dataset's shuffle-buffer procedure by hand and return the resulting frame indices.

    A fresh ``np.random.default_rng(streaming_ds.seed)`` drives three things, in this order:
    picking which live shard to read next, picking which buffer slot to evict once the buffer
    holds ``buffer_size`` frames, and shuffling whatever is left in the buffer at the end.
    The sequence of RNG calls is part of the contract, so statements must not be reordered.
    """
    rng = np.random.default_rng(streaming_ds.seed)
    capacity = streaming_ds.buffer_size
    shard_count = streaming_ds.num_shards

    # Read every shard's frame indices eagerly, then consume each list through its own iterator.
    live_shards = {
        shard_id: iter([row["index"] for row in streaming_ds.hf_dataset.shard(shard_count, index=shard_id)])
        for shard_id in range(shard_count)
    }

    slot_picker = streaming_ds._iter_random_indices(rng, capacity)

    buffer: list[int] = []
    emitted: list[int] = []

    # The loop condition already guarantees at least one live shard on every pass.
    while live_shards:
        # A new generator is built each pass so it only ever sees the shards that are still live.
        chosen = next(streaming_ds._infinite_generator_over_elements(rng, list(live_shards)))

        try:
            frame_index = next(live_shards[chosen])

            if len(buffer) != capacity:
                # Still filling: nothing is emitted yet.
                buffer.append(frame_index)
            else:
                # Full: emit the frame in a random slot and put the new frame in its place.
                slot = next(slot_picker)
                emitted.append(buffer[slot])
                buffer[slot] = frame_index

        except StopIteration:
            # The chosen shard is exhausted; stop offering it.
            del live_shards[chosen]

    # Drain: the leftover buffer is shuffled in place and appended.
    rng.shuffle(buffer)
    emitted.extend(buffer)

    return emitted
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
@@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
@@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
|
||||
for epoch_indices in epochs:
|
||||
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
if shuffle:
|
||||
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
|
||||
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
|
||||
else:
|
||||
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
@@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
def make_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
episode_pool_size=3,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
first = [int(frame["index"]) for frame in make_ds()]
|
||||
again = [int(frame["index"]) for frame in make_ds()]
|
||||
|
||||
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
assert first == again, "same seed must reproduce the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
state = PartialState()
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8
|
||||
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
|
||||
)
|
||||
indices = [int(frame["index"]) for frame in ds]
|
||||
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
|
||||
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
|
||||
prefetching, deterministic fast-forward resume, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -75,7 +76,7 @@ def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=8,
|
||||
episode_pool_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
@@ -101,7 +102,7 @@ def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
@@ -128,7 +129,7 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=1,
|
||||
episode_pool_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
@@ -147,8 +148,8 @@ def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_datas
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
|
||||
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
|
||||
def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory):
|
||||
"""Resume replays the deterministic stream and continues at the exact sample."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
@@ -157,27 +158,93 @@ def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_f
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=7,
|
||||
episode_pool_size=3,
|
||||
max_num_shards=1,
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
stop_after = 40
|
||||
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
|
||||
state = ds.state_dict()
|
||||
assert set(state) == {"shards", "exhausted", "rng"}
|
||||
full_epoch = _stream_indices(fresh_ds())
|
||||
assert sorted(full_epoch) == list(range(total_frames))
|
||||
|
||||
batches_consumed, batch_size = 5, 4 # 20 samples in
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size})
|
||||
resumed = _stream_indices(resumed_ds)
|
||||
|
||||
# Resume continues rather than replaying: the full first pass is not re-yielded.
|
||||
assert len(resumed) < total_frames
|
||||
overlap = set(seen_before) & set(resumed)
|
||||
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
|
||||
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
|
||||
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
|
||||
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
|
||||
assert resumed == full_epoch[batches_consumed * batch_size :], (
|
||||
"fast-forward resume did not continue at the exact sample"
|
||||
)
|
||||
|
||||
|
||||
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
    """The shuffled stream order must repeat for one seed and differ between two seeds."""
    repo_id = f"{DUMMY_REPO_ID}-seeds"
    dataset_root = tmp_path / "ds"
    _make_local_dataset(lerobot_dataset_factory, dataset_root, repo_id, total_episodes=6, total_frames=120)

    def order(seed):
        # A brand-new dataset object per call, so nothing carries over between streams.
        dataset = StreamingLeRobotDataset(
            repo_id=repo_id,
            root=dataset_root,
            shuffle=True,
            seed=seed,
            episode_pool_size=4,
            max_num_shards=2,
        )
        return _stream_indices(dataset)

    first_pass = order(0)
    second_pass = order(0)
    assert first_pass == second_pass, "same seed must reproduce the same order"

    seed_zero = order(0)
    seed_one = order(1)
    assert seed_zero != seed_one, "different seeds should give different orders"
|
||||
|
||||
|
||||
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
    """Two back-to-back passes over one dataset object cover every frame but in different orders."""
    repo_id = f"{DUMMY_REPO_ID}-epochs"
    total_frames = 120
    dataset_root = tmp_path / "ds"
    _make_local_dataset(
        lerobot_dataset_factory, dataset_root, repo_id, total_episodes=6, total_frames=total_frames
    )

    ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=dataset_root,
        shuffle=True,
        seed=3,
        episode_pool_size=4,
        max_num_shards=2,
    )

    # Streaming the same object twice is what yields the second epoch.
    epoch_0 = _stream_indices(ds)
    epoch_1 = _stream_indices(ds)

    every_frame = list(range(total_frames))
    assert sorted(epoch_0) == sorted(epoch_1) and sorted(epoch_1) == every_frame
    assert epoch_0 != epoch_1, "epoch did not reshuffle"
|
||||
|
||||
|
||||
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
    """The first 20 streamed frames must already span at least three distinct episodes."""
    repo_id = f"{DUMMY_REPO_ID}-mix"
    dataset_root = tmp_path / "ds"
    _make_local_dataset(lerobot_dataset_factory, dataset_root, repo_id, total_episodes=8, total_frames=200)

    ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=dataset_root,
        shuffle=True,
        seed=0,
        episode_pool_size=8,
        max_num_shards=4,
    )

    # zip against range(20) caps the read at 20 frames without draining the stream.
    episodes_in_head = set()
    for _, frame in zip(range(20), ds, strict=False):
        episodes_in_head.add(int(frame["episode_index"]))

    assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
|
||||
|
||||
|
||||
def test_video_prefetcher_refcounted_lifecycle(tmp_path):
    """A file acquired twice survives the first release and is removed by the second.

    A directory under ``tmp_path`` stands in for the remote source. The same relative path is
    acquired twice (two pooled episodes sharing one video file), fetched once via
    ``wait_local``, then released twice while checking the local copy's existence in between.
    """
    from lerobot.datasets.streaming_dataset import _VideoPrefetcher

    remote = tmp_path / "remote"
    (remote / "videos").mkdir(parents=True)
    payload = b"x" * 1024
    (remote / "videos" / "a.mp4").write_bytes(payload)

    prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1)
    # try/finally so the prefetcher is always shut down, even when an assertion below fails;
    # otherwise a failing run would leave it open for the rest of the test session.
    try:
        prefetcher.acquire("videos/a.mp4")
        prefetcher.acquire("videos/a.mp4")  # second pooled episode sharing the file
        local = prefetcher.wait_local("videos/a.mp4")
        # Separate asserts so a failure says whether the fetch or the content was wrong.
        assert local is not None
        assert local.read_bytes() == payload

        prefetcher.release("videos/a.mp4")
        assert local.exists(), "file deleted while still referenced"
        prefetcher.release("videos/a.mp4")
        assert not local.exists(), "file not deleted at refcount zero"
    finally:
        prefetcher.shutdown()
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
@@ -187,7 +254,7 @@ def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
@@ -217,7 +284,7 @@ def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypa
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
@@ -239,12 +306,12 @@ def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
|
||||
Reference in New Issue
Block a user