mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
894fc6bfb5
The custom episode pool becomes a pure `datasets` pipeline:
split_dataset_by_node -> batch(by_column="episode_index")
-> shuffle(buffer=episode_pool_size) # episode pool
-> map(explode + exact delta windows) # episode -> frames
-> shuffle(buffer=frame_shuffle_buffer_size) # frame interleave
and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.
Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.
Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade-off accepted with this rewrite: no video prefetch-on-admit, so
remote decode pays per-frame range reads at yield time - use a colocated
bucket (data_files_root) at large scale.
The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.
Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
330 lines
11 KiB
Python
330 lines
11 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import pytest
|
|
import torch
|
|
|
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
|
|
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
|
from lerobot.datasets.utils import safe_shard
|
|
from lerobot.utils.constants import ACTION
|
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
|
|
|
|
|
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
    """Check that each streamed frame matches the same-index item of the factory dataset.

    A dataset is generated under ``tmp_path`` with ``lerobot_dataset_factory`` and streamed
    through ``StreamingLeRobotDataset``. Every streamed frame is compared key by key with
    ``ds[frame["index"]]``.
    """
    ds_num_frames = 400
    ds_num_episodes = 10
    buffer_size = 100

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}"

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
    )

    streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))

    for _ in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        # Checks are collected per frame so a failure is reported against the frame that caused it.
        key_checks = []
        for key in streaming_frame:
            left = streaming_frame[key]
            right = target_frame[key]

            if isinstance(left, str):
                check = left == right

            elif isinstance(left, torch.Tensor):
                check = torch.allclose(left, right) and left.shape == right.shape

            else:
                # Scalar numerics: streaming yields python floats/ints where the indexed dataset
                # yields 0-dim tensors. Compare by value. Without this fallback a python int
                # would silently reuse the previous key's `check` (or hit an unbound name).
                check = float(left) == float(right)

            key_checks.append((key, check))

        failed_keys = [key for key, ok in key_checks if not ok]
        assert not failed_keys, (
            f"Checking {failed_keys[0]} left and right were found different (frame_idx: {frame_idx})"
        )
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", [False, True])
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
    """Iterate the stream for three epochs and check per-epoch coverage and ordering.

    Every epoch must yield each frame index exactly once. With ``shuffle`` the first two
    epochs must differ and the first must not be sequential; without it all three epochs
    must yield the same order.
    """
    total_frames = 400
    dataset_root = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}"

    lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=10,
        total_frames=total_frames,
    )

    stream = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=dataset_root,
        episode_pool_size=4,
        seed=42,
        shuffle=shuffle,
    )

    def run_epoch():
        # One full pass over the same dataset object counts as one epoch.
        return [int(frame["index"]) for frame in stream]

    # All epochs are consumed before any assertion runs.
    first, second, third = (run_epoch(), run_epoch(), run_epoch())

    sequential = list(range(total_frames))
    for order in (first, second, third):
        assert sorted(order) == sequential, "epoch did not cover every frame once"

    if not shuffle:
        assert first == second and second == third, "unshuffled epochs must repeat the same order"
        return

    assert first != second, "shuffle did not reshuffle across epochs"
    assert first != sequential, "shuffle left the stream in sequential order"
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", [False, True])
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
    """Check exactly-once coverage and same-seed reproducibility with ``max_num_shards=4``.

    Two independently constructed streams with identical arguments must yield the same
    sequence of frame indices, and that sequence must cover every frame once.
    """
    total_frames = 100
    dataset_root = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"

    # NOTE(review): the tiny file size and chunk size presumably force the factory to write
    # several data files, so the stream really is multi-shard - confirm against the factory.
    lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=10,
        total_frames=total_frames,
        data_files_size_in_mb=0.001,
        chunks_size=1,
    )

    stream_kwargs = {
        "repo_id": repo_id,
        "root": dataset_root,
        "episode_pool_size": 3,
        "seed": 42,
        "shuffle": shuffle,
        "max_num_shards": 4,
    }

    def collect_indices():
        # A fresh dataset object is built for every pass.
        return [int(frame["index"]) for frame in StreamingLeRobotDataset(**stream_kwargs)]

    first_pass = collect_indices()
    second_pass = collect_indices()

    assert sorted(first_pass) == list(range(total_frames)), "epoch did not cover every frame once"
    assert first_pass == second_pass, "same seed must reproduce the same order"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "state_deltas, action_deltas",
    [
        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
        ([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
        ([-2, -1, -0.5, 0], [0, 1, 2, 3]),
        ([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
    ],
)
def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
    """Compare streamed frames with delta-timestamp windows against the indexed dataset.

    Each streamed frame must expose the same keys as ``ds[frame["index"]]`` and match it
    value by value. Tensors that come with a ``<key>_is_pad`` mask (camera keys excepted)
    are compared on their non-padded entries only.
    """
    total_frames = 500
    dataset_root = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"

    delta_timestamps = {
        "phone": state_deltas,
        "state": state_deltas,
        ACTION: action_deltas,
    }

    ds = lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=10,
        total_frames=total_frames,
        delta_timestamps=delta_timestamps,
    )

    stream = iter(
        StreamingLeRobotDataset(
            repo_id=repo_id,
            root=dataset_root,
            buffer_size=100,
            seed=42,
            shuffle=False,
            delta_timestamps=delta_timestamps,
        )
    )

    for i in range(total_frames):
        streamed = next(stream)
        frame_idx = streamed["index"]
        reference = ds[frame_idx]

        streamed_keys = set(streamed.keys())
        reference_keys = set(reference.keys())
        assert streamed_keys == reference_keys, (
            f"Keys differ between streaming frame and target one. Differ at: {streamed_keys - reference_keys}"
        )

        outcomes = {}
        for key in streamed:
            left, right = streamed[key], reference[key]

            if isinstance(left, str):
                outcomes[key] = left == right
                continue

            if not isinstance(left, torch.Tensor):
                # Scalar numerics: streaming yields python floats/ints where map-style yields
                # 0-dim tensors (long-standing accepted difference). Compare by value.
                outcomes[key] = float(left) == float(right)
                continue

            pad_key = f"{key}_is_pad"
            if pad_key in streamed and "is_pad" not in key and key not in ds.meta.camera_keys:
                # Only the non-padded entries are compared; padded slots are dropped on both sides.
                left = left[~streamed[pad_key]]
                right = right[~reference[pad_key]]

            outcomes[key] = torch.allclose(left, right) and left.shape == right.shape

        mismatched = [key for key, ok in outcomes.items() if not ok]
        assert not mismatched, (
            f"Checking {mismatched[0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "state_deltas, action_deltas",
    [
        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
        ([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
        ([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
        ([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
    ],
)
def test_frames_with_delta_consistency_with_shards(
    tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
):
    """Compare streamed delta-window frames against the indexed dataset with ``max_num_shards=4``.

    Each streamed frame must expose the same keys as ``ds[frame["index"]]`` and match it
    value by value. Tensors that come with a ``<key>_is_pad`` mask (camera keys excepted)
    are compared on their non-padded entries only.
    """
    ds_num_frames = 100
    ds_num_episodes = 10
    buffer_size = 10
    data_file_size_mb = 0.001
    chunks_size = 1

    seed = 42

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"
    camera_key = "phone"

    delta_timestamps = {
        camera_key: state_deltas,
        "state": state_deltas,
        ACTION: action_deltas,
    }

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        delta_timestamps=delta_timestamps,
        data_files_size_in_mb=data_file_size_mb,
        chunks_size=chunks_size,
    )
    streaming_ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=local_path,
        buffer_size=buffer_size,
        seed=seed,
        shuffle=False,
        delta_timestamps=delta_timestamps,
        max_num_shards=4,
    )

    # NOTE(review): the iterator returned here is discarded; presumably this only primes
    # lazily-built state on the dataset - confirm it is still needed.
    iter(streaming_ds)

    # NOTE(review): the per-shard indices are collected but never asserted on, so this loop
    # only checks that every shard can be iterated without raising.
    num_shards = 4
    shards_indices = []
    for shard_idx in range(num_shards):
        shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
        shard_indices = [item["index"] for item in shard]
        shards_indices.append(shard_indices)

    streaming_ds = iter(streaming_ds)

    for i in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
            f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
        )

        key_checks = []
        for key in streaming_frame:
            left = streaming_frame[key]
            right = target_frame[key]

            if isinstance(left, str):
                check = left == right

            elif isinstance(left, torch.Tensor):
                if (
                    key not in ds.meta.camera_keys
                    and "is_pad" not in key
                    and f"{key}_is_pad" in streaming_frame
                ):
                    # comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
                    left = left[~streaming_frame[f"{key}_is_pad"]]
                    right = right[~target_frame[f"{key}_is_pad"]]

                check = torch.allclose(left, right) and left.shape == right.shape

            else:
                # Scalar numerics: streaming yields python floats/ints where map-style yields
                # 0-dim tensors. Compare by value. Without this fallback a python int would
                # silently reuse the previous key's `check` (or hit an unbound name).
                check = float(left) == float(right)

            key_checks.append((key, check))

        failed_keys = [key for key, ok in key_checks if not ok]
        assert not failed_keys, (
            f"Checking {failed_keys[0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
        )
|