feat(streaming): native datasets-5 episode batching and worker-split suppression

Allow datasets 5.x (pin >=4.7,<6; lockfile moves to 5.0.0) and use its
Arrow-native batch(by_column="episode_index") (huggingface/datasets#8172,
a sibling of #8194) for episode admission when available - one Arrow
accumulation per episode instead of one Python dict per row - with the
existing row loop as the 4.x fallback. A parity test asserts both paths
group identically.

Also fixes a latent worker bug this surfaced: `datasets` detects torch
DataLoader workers and re-splits its shards internally (_iter_pytorch),
on top of our explicit per-worker shard assignment. That second split
silently drops data whenever a per-worker stream has fewer internal
shards than there are workers (masked so far by single-file test
fixtures), and on datasets 5.0 it crashes by_column batching outright.
The worker context is now hidden from `datasets` while draining streams
we already partitioned (process-local patch, restored on exit).

The multi-shard shuffle buffer (huggingface/datasets#8194) is
intentionally NOT used: frame-level shuffling upstream of episode
grouping would fragment episodes and break delta windows. Its threaded
multi-source prefetch idea remains a follow-up for episode admission if
fetch timings warrant it.

Verified on both datasets 4.8.5 (fallback) and 5.0.0 (native): 27/27
streaming tests each; full datasets suite 469 passed under 5.0.0.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 16:10:53 +02:00
parent 79b547de32
commit a164bb97bd
4 changed files with 99 additions and 34 deletions
+1 -1
View File
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.7.0,<5.0.0",
"datasets>=4.7.0,<6.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
+69 -29
View File
@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import logging
import os
import shutil
@@ -46,6 +48,10 @@ from .video_utils import (
logger = logging.getLogger(__name__)
# datasets >= 5 groups a stream into whole-episode batches natively (Arrow-side accumulation,
# https://github.com/huggingface/datasets/pull/8172); older versions fall back to a Python row loop.
_HAS_BATCH_BY_COLUMN = "by_column" in inspect.signature(datasets.IterableDataset.batch).parameters
_MASK_64 = (1 << 64) - 1
@@ -60,6 +66,24 @@ def _mix64(x: int) -> int:
return x
@contextlib.contextmanager
def _suppress_hf_worker_split():
"""Hide the torch DataLoader worker context from `datasets` while we drain its streams.
`datasets` detects torch workers and re-splits its shards across them internally
(`_iter_pytorch`); this dataset already assigns disjoint shards per worker, so the second
split silently drops data whenever a per-worker stream has fewer internal shards than there
are workers — and on datasets 5.0 it also crashes `batch(by_column=...)`. The patch is local
to this DataLoader worker process and restored on exit.
"""
original = torch.utils.data.get_worker_info
torch.utils.data.get_worker_info = lambda: None
try:
yield
finally:
torch.utils.data.get_worker_info = original
class _PooledEpisode:
"""A fully-loaded episode's tabular rows plus emission bookkeeping."""
@@ -410,7 +434,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
@staticmethod
def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]:
"""Yield (episode_index, rows) for each complete episode of a shard stream."""
"""Yield (episode_index, rows) for each complete episode of a shard stream.
On datasets >= 5 the grouping runs natively in Arrow via ``batch(by_column=...)``
(one accumulation per episode instead of one Python dict per row); older versions
use the equivalent row loop.
"""
if _HAS_BATCH_BY_COLUMN:
for batch in shard.batch(by_column="episode_index"):
keys = list(batch.keys())
num_rows = len(batch["episode_index"])
rows = [{key: batch[key][i] for key in keys} for i in range(num_rows)]
yield int(batch["episode_index"][0]), rows
return
rows: list[dict] = []
current: int | None = None
for item in shard:
@@ -496,37 +532,41 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
admitted += 1
return admitted
worker_split_guard = (
_suppress_hf_worker_split() if worker_info is not None else contextlib.nullcontext()
)
try:
admit()
while pool:
if self.shuffle:
# Uniform draw over every remaining frame in the pool: pick the episode by
# cumulative remaining count, then a random remaining position (swap-pop).
draw = int(rng.integers(total_remaining))
for episode in pool:
if draw < len(episode.remaining):
break
draw -= len(episode.remaining)
pick = int(rng.integers(len(episode.remaining)))
frame_pos = episode.remaining[pick]
episode.remaining[pick] = episode.remaining[-1]
episode.remaining.pop()
else:
episode = pool[0]
frame_pos = episode.remaining.pop(0)
total_remaining -= 1
with worker_split_guard:
admit()
while pool:
if self.shuffle:
# Uniform draw over every remaining frame in the pool: pick the episode by
# cumulative remaining count, then a random remaining position (swap-pop).
draw = int(rng.integers(total_remaining))
for episode in pool:
if draw < len(episode.remaining):
break
draw -= len(episode.remaining)
pick = int(rng.integers(len(episode.remaining)))
frame_pos = episode.remaining[pick]
episode.remaining[pick] = episode.remaining[-1]
episode.remaining.pop()
else:
episode = pool[0]
frame_pos = episode.remaining.pop(0)
total_remaining -= 1
if self._ff_remaining > 0:
self._ff_remaining -= 1
else:
yield self._make_pool_sample(episode, frame_pos)
if self._ff_remaining > 0:
self._ff_remaining -= 1
else:
yield self._make_pool_sample(episode, frame_pos)
if not episode.remaining:
pool.remove(episode)
if prefetcher is not None:
for rel in episode.video_rel_paths:
prefetcher.release(rel)
admit()
if not episode.remaining:
pool.remove(episode)
if prefetcher is not None:
for rel in episode.video_rel_paths:
prefetcher.release(rel)
admit()
finally:
if prefetcher is not None:
prefetcher.shutdown()
+25
View File
@@ -353,3 +353,28 @@ def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_f
assert resumed == full[samples_consumed:], (
"fast-forward resume with DataLoader workers did not continue at the exact sample"
)
def test_episode_grouping_native_and_fallback_agree(tmp_path, lerobot_dataset_factory, monkeypatch):
    """The datasets>=5 batch(by_column=...) path must group episodes identically to the row loop.

    The fallback (row loop) path is always exercised. The native path is only exercised when
    the installed `datasets` actually supports ``batch(by_column=...)``, as recorded by the
    module-level ``_HAS_BATCH_BY_COLUMN`` flag at import time.
    """
    import lerobot.datasets.streaming_dataset as sd

    # Capture the real capability flag BEFORE any monkeypatching: `episode_signature`
    # overwrites `sd._HAS_BATCH_BY_COLUMN`, so reading it afterwards would only reflect
    # the value the test itself just set.
    native_supported = bool(sd._HAS_BATCH_BY_COLUMN)

    repo_id = f"{DUMMY_REPO_ID}-grouping"
    _make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=100)
    ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, max_num_shards=1)

    def episode_signature(use_native):
        # Force one of the two grouping paths and summarise each episode as
        # (episode_index, [global frame indices]) for an exact comparison.
        monkeypatch.setattr(sd, "_HAS_BATCH_BY_COLUMN", use_native)
        return [
            (ep_idx, [int(row["index"]) for row in rows])
            for ep_idx, rows in ds._iter_shard_episodes(ds.hf_dataset)
        ]

    fallback = episode_signature(False)
    assert len(fallback) == 5

    if not native_supported:  # datasets < 5: only the fallback path exists
        return

    native = episode_signature(True)
    assert native == fallback
Generated
+4 -4
View File
@@ -1084,7 +1084,7 @@ wheels = [
[[package]]
name = "datasets"
version = "4.8.5"
version = "5.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dill" },
@@ -1102,9 +1102,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d9/85/ce4f780c32f7e36d71257f1c27e8ba898ebe379cb54f211f5f2013f2c219/datasets-5.0.0.tar.gz", hash = "sha256:83dbbbdb07a33b82192b8c419deb18739b138ee2ce1a322d55ce6b100954ec1a", size = 631708, upload-time = "2026-06-05T13:18:26.124Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
{ url = "https://files.pythonhosted.org/packages/05/66/73034ad30b59f13439b75e620989dacba4c047256e358ba7c2e9ec98ea22/datasets-5.0.0-py3-none-any.whl", hash = "sha256:7dd34927a0fd7046e98aad5cb9430e699c373238a15befa7b9bf22b991a7fee6", size = 555084, upload-time = "2026-06-05T13:18:24.435Z" },
]
[[package]]
@@ -3078,7 +3078,7 @@ requires-dist = [
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<6.0.0" },
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },