mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
feat(streaming): native datasets-5 episode batching and worker-split suppression
Allow datasets 5.x (pin >=4.7,<6; lockfile moves to 5.0.0) and use its Arrow-native batch(by_column="episode_index") (huggingface/datasets#8194 sibling, #8172) for episode admission when available - one Arrow accumulation per episode instead of one Python dict per row - with the existing row loop as the 4.x fallback. A parity test asserts both paths group identically. Also fixes a latent worker bug this surfaced: `datasets` detects torch DataLoader workers and re-splits its shards internally (_iter_pytorch), on top of our explicit per-worker shard assignment. That second split silently drops data whenever a per-worker stream has fewer internal shards than there are workers (masked so far by single-file test fixtures), and on datasets 5.0 it crashes by_column batching outright. The worker context is now hidden from `datasets` while draining streams we already partitioned (process-local patch, restored on exit). The multi-shard shuffle buffer (huggingface/datasets#8194) is intentionally NOT used: frame-level shuffling upstream of episode grouping would fragment episodes and break delta windows. Its threaded multi-source prefetch idea remains a follow-up for episode admission if fetch timings warrant it. Verified on both datasets 4.8.5 (fallback) and 5.0.0 (native): 27/27 streaming tests each; full datasets suite 469 passed under 5.0.0. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"datasets>=4.7.0,<6.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@@ -46,6 +48,10 @@ from .video_utils import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# datasets >= 5 groups a stream into whole-episode batches natively (Arrow-side accumulation,
|
||||
# https://github.com/huggingface/datasets/pull/8172); older versions fall back to a Python row loop.
|
||||
_HAS_BATCH_BY_COLUMN = "by_column" in inspect.signature(datasets.IterableDataset.batch).parameters
|
||||
|
||||
_MASK_64 = (1 << 64) - 1
|
||||
|
||||
|
||||
@@ -60,6 +66,24 @@ def _mix64(x: int) -> int:
|
||||
return x
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _suppress_hf_worker_split():
|
||||
"""Hide the torch DataLoader worker context from `datasets` while we drain its streams.
|
||||
|
||||
`datasets` detects torch workers and re-splits its shards across them internally
|
||||
(`_iter_pytorch`); this dataset already assigns disjoint shards per worker, so the second
|
||||
split silently drops data whenever a per-worker stream has fewer internal shards than there
|
||||
are workers — and on datasets 5.0 it also crashes `batch(by_column=...)`. The patch is local
|
||||
to this DataLoader worker process and restored on exit.
|
||||
"""
|
||||
original = torch.utils.data.get_worker_info
|
||||
torch.utils.data.get_worker_info = lambda: None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.utils.data.get_worker_info = original
|
||||
|
||||
|
||||
class _PooledEpisode:
|
||||
"""A fully-loaded episode's tabular rows plus emission bookkeeping."""
|
||||
|
||||
@@ -410,7 +434,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
@staticmethod
|
||||
def _iter_shard_episodes(shard: datasets.IterableDataset) -> Iterator[tuple[int, list[dict]]]:
|
||||
"""Yield (episode_index, rows) for each complete episode of a shard stream."""
|
||||
"""Yield (episode_index, rows) for each complete episode of a shard stream.
|
||||
|
||||
On datasets >= 5 the grouping runs natively in Arrow via ``batch(by_column=...)``
|
||||
(one accumulation per episode instead of one Python dict per row); older versions
|
||||
use the equivalent row loop.
|
||||
"""
|
||||
if _HAS_BATCH_BY_COLUMN:
|
||||
for batch in shard.batch(by_column="episode_index"):
|
||||
keys = list(batch.keys())
|
||||
num_rows = len(batch["episode_index"])
|
||||
rows = [{key: batch[key][i] for key in keys} for i in range(num_rows)]
|
||||
yield int(batch["episode_index"][0]), rows
|
||||
return
|
||||
rows: list[dict] = []
|
||||
current: int | None = None
|
||||
for item in shard:
|
||||
@@ -496,37 +532,41 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
admitted += 1
|
||||
return admitted
|
||||
|
||||
worker_split_guard = (
|
||||
_suppress_hf_worker_split() if worker_info is not None else contextlib.nullcontext()
|
||||
)
|
||||
try:
|
||||
admit()
|
||||
while pool:
|
||||
if self.shuffle:
|
||||
# Uniform draw over every remaining frame in the pool: pick the episode by
|
||||
# cumulative remaining count, then a random remaining position (swap-pop).
|
||||
draw = int(rng.integers(total_remaining))
|
||||
for episode in pool:
|
||||
if draw < len(episode.remaining):
|
||||
break
|
||||
draw -= len(episode.remaining)
|
||||
pick = int(rng.integers(len(episode.remaining)))
|
||||
frame_pos = episode.remaining[pick]
|
||||
episode.remaining[pick] = episode.remaining[-1]
|
||||
episode.remaining.pop()
|
||||
else:
|
||||
episode = pool[0]
|
||||
frame_pos = episode.remaining.pop(0)
|
||||
total_remaining -= 1
|
||||
with worker_split_guard:
|
||||
admit()
|
||||
while pool:
|
||||
if self.shuffle:
|
||||
# Uniform draw over every remaining frame in the pool: pick the episode by
|
||||
# cumulative remaining count, then a random remaining position (swap-pop).
|
||||
draw = int(rng.integers(total_remaining))
|
||||
for episode in pool:
|
||||
if draw < len(episode.remaining):
|
||||
break
|
||||
draw -= len(episode.remaining)
|
||||
pick = int(rng.integers(len(episode.remaining)))
|
||||
frame_pos = episode.remaining[pick]
|
||||
episode.remaining[pick] = episode.remaining[-1]
|
||||
episode.remaining.pop()
|
||||
else:
|
||||
episode = pool[0]
|
||||
frame_pos = episode.remaining.pop(0)
|
||||
total_remaining -= 1
|
||||
|
||||
if self._ff_remaining > 0:
|
||||
self._ff_remaining -= 1
|
||||
else:
|
||||
yield self._make_pool_sample(episode, frame_pos)
|
||||
if self._ff_remaining > 0:
|
||||
self._ff_remaining -= 1
|
||||
else:
|
||||
yield self._make_pool_sample(episode, frame_pos)
|
||||
|
||||
if not episode.remaining:
|
||||
pool.remove(episode)
|
||||
if prefetcher is not None:
|
||||
for rel in episode.video_rel_paths:
|
||||
prefetcher.release(rel)
|
||||
admit()
|
||||
if not episode.remaining:
|
||||
pool.remove(episode)
|
||||
if prefetcher is not None:
|
||||
for rel in episode.video_rel_paths:
|
||||
prefetcher.release(rel)
|
||||
admit()
|
||||
finally:
|
||||
if prefetcher is not None:
|
||||
prefetcher.shutdown()
|
||||
|
||||
@@ -353,3 +353,28 @@ def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_f
|
||||
assert resumed == full[samples_consumed:], (
|
||||
"fast-forward resume with DataLoader workers did not continue at the exact sample"
|
||||
)
|
||||
|
||||
|
||||
def test_episode_grouping_native_and_fallback_agree(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""The datasets>=5 batch(by_column=...) path must group episodes identically to the row loop."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-grouping"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=100)
|
||||
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, max_num_shards=1)
|
||||
|
||||
def episode_signature(use_native):
|
||||
monkeypatch.setattr(sd, "_HAS_BATCH_BY_COLUMN", use_native)
|
||||
return [
|
||||
(ep_idx, [int(row["index"]) for row in rows])
|
||||
for ep_idx, rows in ds._iter_shard_episodes(ds.hf_dataset)
|
||||
]
|
||||
|
||||
fallback = episode_signature(False)
|
||||
assert len(fallback) == 5
|
||||
if not sd._HAS_BATCH_BY_COLUMN and "by_column" not in str(
|
||||
type(ds.hf_dataset).batch.__doc__ or ""
|
||||
): # datasets < 5: only the fallback path exists
|
||||
return
|
||||
native = episode_signature(True)
|
||||
assert native == fallback
|
||||
|
||||
@@ -1084,7 +1084,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "4.8.5"
|
||||
version = "5.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
@@ -1102,9 +1102,9 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d9/85/ce4f780c32f7e36d71257f1c27e8ba898ebe379cb54f211f5f2013f2c219/datasets-5.0.0.tar.gz", hash = "sha256:83dbbbdb07a33b82192b8c419deb18739b138ee2ce1a322d55ce6b100954ec1a", size = 631708, upload-time = "2026-06-05T13:18:26.124Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/66/73034ad30b59f13439b75e620989dacba4c047256e358ba7c2e9ec98ea22/datasets-5.0.0-py3-none-any.whl", hash = "sha256:7dd34927a0fd7046e98aad5cb9430e699c373238a15befa7b9bf22b991a7fee6", size = 555084, upload-time = "2026-06-05T13:18:24.435Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3078,7 +3078,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<6.0.0" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user