From a164bb97bd53ec9d0b9a2f9d73a8a21b3c5901c5 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 11 Jun 2026 16:10:53 +0200 Subject: [PATCH] feat(streaming): native datasets-5 episode batching and worker-split suppression Allow datasets 5.x (pin >=4.7,<6; lockfile moves to 5.0.0) and use its Arrow-native batch(by_column="episode_index") (huggingface/datasets#8172, a sibling of #8194) for episode admission when available - one Arrow accumulation per episode instead of one Python dict per row - with the existing row loop as the 4.x fallback. A parity test asserts both paths group identically. This also fixes a latent worker bug that the change surfaced: `datasets` detects torch DataLoader workers and re-splits its shards internally (_iter_pytorch), on top of our explicit per-worker shard assignment. That second split silently drops data whenever a per-worker stream has fewer internal shards than there are workers (masked so far by single-file test fixtures), and on datasets 5.0 it crashes by_column batching outright. The worker context is now hidden from `datasets` while draining streams we already partitioned (process-local patch, restored on exit). The multi-shard shuffle buffer (huggingface/datasets#8194) is intentionally NOT used: frame-level shuffling upstream of episode grouping would fragment episodes and break delta windows. Its threaded multi-source prefetch idea remains a follow-up for episode admission if fetch timings warrant it. Verified on both datasets 4.8.5 (fallback) and 5.0.0 (native): 27/27 streaming tests each; full datasets suite 469 passed under 5.0.0. 
Co-Authored-By: Claude Fable 5 --- pyproject.toml | 2 +- src/lerobot/datasets/streaming_dataset.py | 98 ++++++++++++++++------- tests/datasets/test_streaming_native.py | 25 ++++++ uv.lock | 8 +- 4 files changed, 99 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f72cfa6dd..3a67cac28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ dependencies = [ # ── Feature-scoped extras ────────────────────────────────── dataset = [ - "datasets>=4.7.0,<5.0.0", + "datasets>=4.7.0,<6.0.0", "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets "lerobot[av-dep]", diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 2f0209575..c03e19c61 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import inspect import logging import os import shutil @@ -46,6 +48,10 @@ from .video_utils import ( logger = logging.getLogger(__name__) +# datasets >= 5 groups a stream into whole-episode batches natively (Arrow-side accumulation, +# https://github.com/huggingface/datasets/pull/8172); older versions fall back to a Python row loop. +_HAS_BATCH_BY_COLUMN = "by_column" in inspect.signature(datasets.IterableDataset.batch).parameters + _MASK_64 = (1 << 64) - 1 @@ -60,6 +66,24 @@ def _mix64(x: int) -> int: return x +@contextlib.contextmanager +def _suppress_hf_worker_split(): + """Hide the torch DataLoader worker context from `datasets` while we drain its streams. 
+ + `datasets` detects torch workers and re-splits its shards across them internally + (`_iter_pytorch`); this dataset already assigns disjoint shards per worker, so the second + split silently drops data whenever a per-worker stream has fewer internal shards than there + are workers — and on datasets 5.0 it also crashes `batch(by_column=...)`. The patch is local + to this DataLoader worker process and restored on exit. + """ + original = torch.utils.data.get_worker_info + torch.utils.data.get_worker_info = lambda: None + try: + yield + finally: + torch.utils.data.get_worker_info = original + + class _PooledEpisode: """A fully-loaded episode's tabular rows plus emission bookkeeping.""" @@ -410,7 +434,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): @staticmethod def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]: - """Yield (episode_index, rows) for each complete episode of a shard stream.""" + """Yield (episode_index, rows) for each complete episode of a shard stream. + + On datasets >= 5 the grouping runs natively in Arrow via ``batch(by_column=...)`` + (one accumulation per episode instead of one Python dict per row); older versions + use the equivalent row loop. 
+ """ + if _HAS_BATCH_BY_COLUMN: + for batch in shard.batch(by_column="episode_index"): + keys = list(batch.keys()) + num_rows = len(batch["episode_index"]) + rows = [{key: batch[key][i] for key in keys} for i in range(num_rows)] + yield int(batch["episode_index"][0]), rows + return rows: list[dict] = [] current: int | None = None for item in shard: @@ -496,37 +532,41 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): admitted += 1 return admitted + worker_split_guard = ( + _suppress_hf_worker_split() if worker_info is not None else contextlib.nullcontext() + ) try: - admit() - while pool: - if self.shuffle: - # Uniform draw over every remaining frame in the pool: pick the episode by - # cumulative remaining count, then a random remaining position (swap-pop). - draw = int(rng.integers(total_remaining)) - for episode in pool: - if draw < len(episode.remaining): - break - draw -= len(episode.remaining) - pick = int(rng.integers(len(episode.remaining))) - frame_pos = episode.remaining[pick] - episode.remaining[pick] = episode.remaining[-1] - episode.remaining.pop() - else: - episode = pool[0] - frame_pos = episode.remaining.pop(0) - total_remaining -= 1 + with worker_split_guard: + admit() + while pool: + if self.shuffle: + # Uniform draw over every remaining frame in the pool: pick the episode by + # cumulative remaining count, then a random remaining position (swap-pop). 
+ draw = int(rng.integers(total_remaining)) + for episode in pool: + if draw < len(episode.remaining): + break + draw -= len(episode.remaining) + pick = int(rng.integers(len(episode.remaining))) + frame_pos = episode.remaining[pick] + episode.remaining[pick] = episode.remaining[-1] + episode.remaining.pop() + else: + episode = pool[0] + frame_pos = episode.remaining.pop(0) + total_remaining -= 1 - if self._ff_remaining > 0: - self._ff_remaining -= 1 - else: - yield self._make_pool_sample(episode, frame_pos) + if self._ff_remaining > 0: + self._ff_remaining -= 1 + else: + yield self._make_pool_sample(episode, frame_pos) - if not episode.remaining: - pool.remove(episode) - if prefetcher is not None: - for rel in episode.video_rel_paths: - prefetcher.release(rel) - admit() + if not episode.remaining: + pool.remove(episode) + if prefetcher is not None: + for rel in episode.video_rel_paths: + prefetcher.release(rel) + admit() finally: if prefetcher is not None: prefetcher.shutdown() diff --git a/tests/datasets/test_streaming_native.py b/tests/datasets/test_streaming_native.py index 47cb9a7cd..f616bfe8d 100644 --- a/tests/datasets/test_streaming_native.py +++ b/tests/datasets/test_streaming_native.py @@ -353,3 +353,28 @@ def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_f assert resumed == full[samples_consumed:], ( "fast-forward resume with DataLoader workers did not continue at the exact sample" ) + + +def test_episode_grouping_native_and_fallback_agree(tmp_path, lerobot_dataset_factory, monkeypatch): + """The datasets>=5 batch(by_column=...) 
path must group episodes identically to the row loop.""" + import lerobot.datasets.streaming_dataset as sd + + repo_id = f"{DUMMY_REPO_ID}-grouping" + _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=100) + ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, max_num_shards=1) + + def episode_signature(use_native): + monkeypatch.setattr(sd, "_HAS_BATCH_BY_COLUMN", use_native) + return [ + (ep_idx, [int(row["index"]) for row in rows]) + for ep_idx, rows in ds._iter_shard_episodes(ds.hf_dataset) + ] + + fallback = episode_signature(False) + assert len(fallback) == 5 + if not sd._HAS_BATCH_BY_COLUMN and "by_column" not in str( + type(ds.hf_dataset).batch.__doc__ or "" + ): # datasets < 5: only the fallback path exists + return + native = episode_signature(True) + assert native == fallback diff --git a/uv.lock b/uv.lock index 3a7129dac..7ff9d2466 100644 --- a/uv.lock +++ b/uv.lock @@ -1084,7 +1084,7 @@ wheels = [ [[package]] name = "datasets" -version = "4.8.5" +version = "5.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dill" }, @@ -1102,9 +1102,9 @@ dependencies = [ { name = "tqdm" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/85/ce4f780c32f7e36d71257f1c27e8ba898ebe379cb54f211f5f2013f2c219/datasets-5.0.0.tar.gz", hash = "sha256:83dbbbdb07a33b82192b8c419deb18739b138ee2ce1a322d55ce6b100954ec1a", size = 631708, upload-time = "2026-06-05T13:18:26.124Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = 
"sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" }, + { url = "https://files.pythonhosted.org/packages/05/66/73034ad30b59f13439b75e620989dacba4c047256e358ba7c2e9ec98ea22/datasets-5.0.0-py3-none-any.whl", hash = "sha256:7dd34927a0fd7046e98aad5cb9430e699c373238a15befa7b9bf22b991a7fee6", size = 555084, upload-time = "2026-06-05T13:18:24.435Z" }, ] [[package]] @@ -3078,7 +3078,7 @@ requires-dist = [ { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" }, { name = "cmake", specifier = ">=3.29.0.1,<4.2.0" }, { name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" }, - { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" }, + { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<6.0.0" }, { name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" }, { name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" }, { name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },