#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")

from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import safe_shard
from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID


def _values_match(left, right):
    """Return whether a streamed value agrees with its map-style counterpart.

    Strings are compared for equality, tensors must have the same shape and be
    element-wise close, and every other value is compared as a float.
    """
    if isinstance(left, str):
        return left == right
    if isinstance(left, torch.Tensor):
        # Shape is checked first so that mismatched shapes yield a clean False
        # instead of being broadcast (or raising) inside torch.allclose.
        return left.shape == right.shape and torch.allclose(left, right)
    # Scalar numerics: streaming yields python floats/ints where map-style yields
    # 0-dim tensors (long-standing accepted difference). Compare by value.
    return float(left) == float(right)


def _mismatched_keys(streaming_frame, target_frame, camera_keys=None):
    """Return the keys of ``streaming_frame`` whose value differs from ``target_frame``.

    Args:
        streaming_frame: Frame produced by the streaming dataset.
        target_frame: Frame read from the map-style dataset; it must contain
            every key of ``streaming_frame``.
        camera_keys: When not ``None``, tensors stored under a key that is not in
            ``camera_keys``, is not itself a padding mask, and has a companion
            ``"<key>_is_pad"`` mask are compared on their non-padded entries only.
            When ``None``, values are compared as-is.

    Returns:
        The differing keys, in iteration order of ``streaming_frame``.
    """
    mismatched = []
    for key in streaming_frame:
        left = streaming_frame[key]
        right = target_frame[key]

        pad_key = f"{key}_is_pad"
        if (
            camera_keys is not None
            and isinstance(left, torch.Tensor)
            and key not in camera_keys
            and "is_pad" not in key
            and pad_key in streaming_frame
        ):
            # comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
            left = left[~streaming_frame[pad_key]]
            right = right[~target_frame[pad_key]]

        if not _values_match(left, right):
            mismatched.append(key)
    return mismatched


def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
    """Test that streamed frames match the frames read by index from the map-style dataset."""
    ds_num_frames = 400
    ds_num_episodes = 10
    buffer_size = 100

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}"

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
    )

    streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))

    for _ in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        # Checked per frame so that the reported frame_idx is the failing one.
        mismatched = _mismatched_keys(streaming_frame, target_frame)
        assert not mismatched, (
            f"Checking {mismatched[0]} left and right were found different (frame_idx: {frame_idx})"
        )


@pytest.mark.parametrize(
    "shuffle",
    [False, True],
)
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
    """Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
    ds_num_frames = 400
    ds_num_episodes = 10
    seed = 42
    n_epochs = 3

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}"

    lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
    )

    streaming_ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
    )

    # Iterating the same dataset object repeatedly yields one index list per epoch.
    epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]

    for epoch_indices in epochs:
        assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"

    if shuffle:
        assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
        assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
    else:
        assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"


@pytest.mark.parametrize(
    "shuffle",
    [False, True],
)
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
    """Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
    ds_num_frames = 100
    ds_num_episodes = 10
    seed = 42

    data_file_size_mb = 0.001
    chunks_size = 1

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"

    lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        data_files_size_in_mb=data_file_size_mb,
        chunks_size=chunks_size,
    )

    def make_ds():
        return StreamingLeRobotDataset(
            repo_id=repo_id,
            root=local_path,
            episode_pool_size=3,
            seed=seed,
            shuffle=shuffle,
            max_num_shards=4,
        )

    # Two independently built datasets with the same seed must stream identically.
    first = [int(frame["index"]) for frame in make_ds()]
    again = [int(frame["index"]) for frame in make_ds()]

    assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
    assert first == again, "same seed must reproduce the same order"


@pytest.mark.parametrize(
    "state_deltas, action_deltas",
    [
        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
        ([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
        ([-2, -1, -0.5, 0], [0, 1, 2, 3]),
        ([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
    ],
)
def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
    """Streamed frames built with delta timestamps match the map-style frames.

    Both frames must expose the same keys; non-camera tensors that have a
    padding mask are compared on their non-padded entries only.
    """
    ds_num_frames = 500
    ds_num_episodes = 10
    buffer_size = 100

    seed = 42
    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"
    camera_key = "phone"

    delta_timestamps = {
        camera_key: state_deltas,
        "state": state_deltas,
        ACTION: action_deltas,
    }

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        delta_timestamps=delta_timestamps,
    )

    streaming_ds = iter(
        StreamingLeRobotDataset(
            repo_id=repo_id,
            root=local_path,
            buffer_size=buffer_size,
            seed=seed,
            shuffle=False,
            delta_timestamps=delta_timestamps,
        )
    )

    camera_keys = ds.meta.camera_keys

    for i in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
            f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
        )

        mismatched = _mismatched_keys(streaming_frame, target_frame, camera_keys=camera_keys)
        assert not mismatched, (
            f"Checking {mismatched[0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
        )


@pytest.mark.parametrize(
    "state_deltas, action_deltas",
    [
        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
        ([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
        ([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
        ([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
    ],
)
def test_frames_with_delta_consistency_with_shards(
    tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
):
    """Same delta-timestamp consistency check, on a dataset written as many small files.

    The streaming dataset is created with ``max_num_shards`` set, and each shard
    of its ``hf_dataset`` is iterated once before the frames are compared.
    """
    ds_num_frames = 100
    ds_num_episodes = 10
    buffer_size = 10

    data_file_size_mb = 0.001
    chunks_size = 1

    seed = 42
    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"
    camera_key = "phone"

    delta_timestamps = {
        camera_key: state_deltas,
        "state": state_deltas,
        ACTION: action_deltas,
    }

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        delta_timestamps=delta_timestamps,
        data_files_size_in_mb=data_file_size_mb,
        chunks_size=chunks_size,
    )

    # Single source of truth for the shard count used by the dataset and by safe_shard.
    num_shards = 4

    streaming_ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=local_path,
        buffer_size=buffer_size,
        seed=seed,
        shuffle=False,
        delta_timestamps=delta_timestamps,
        max_num_shards=num_shards,
    )

    # NOTE(review): the iterator returned here is discarded; presumably kept for a
    # side effect of __iter__ -- confirm before removing.
    iter(streaming_ds)

    # NOTE(review): the collected indices are never asserted on, so this only
    # verifies that every shard can be built and iterated -- confirm whether a
    # coverage assertion was intended.
    shards_indices = []
    for shard_idx in range(num_shards):
        shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
        shard_indices = [item["index"] for item in shard]
        shards_indices.append(shard_indices)

    streaming_iter = iter(streaming_ds)
    camera_keys = ds.meta.camera_keys

    for i in range(ds_num_frames):
        streaming_frame = next(streaming_iter)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
            f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
        )

        mismatched = _mismatched_keys(streaming_frame, target_frame, camera_keys=camera_keys)
        assert not mismatched, (
            f"Checking {mismatched[0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
        )