mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
feat(streaming): random-episode admission via reshard() + multi-input-shard shuffle
Reshard parquet per row group (1 shard == 1 row group == 1 episode) and feed the episode-pool shuffle with max_buffer_input_shards so the pool is a uniform random sample of the corpus, independent of episodes-per-file. Add validate_row_groups guardrails (collapsed-row-group + distributed divisibility), require datasets>=5.0.0, make the test fixture write one row group per episode, and plumb max_buffer_input_shards through the dataloading benchmark. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -222,8 +222,12 @@ def build_dataset(scenario: str, args: argparse.Namespace):
|
||||
revision=MOLMO_REVISION,
|
||||
data_files_root=data_files_root,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
max_buffer_input_shards=args.max_buffer_input_shards,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
# Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
|
||||
# dataset may be collapsed); reshard() still yields per-episode shards where it holds.
|
||||
validate_row_groups=False,
|
||||
)
|
||||
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
|
||||
|
||||
@@ -362,6 +366,9 @@ def run_scenario(scenario: str, args: argparse.Namespace) -> None:
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"episode_pool_size": None if is_map_style else args.episode_pool_size,
|
||||
"max_buffer_input_shards": None
|
||||
if is_map_style
|
||||
else (args.max_buffer_input_shards or args.episode_pool_size),
|
||||
**info,
|
||||
"num_cameras": num_cameras,
|
||||
"image_shape": image_shape,
|
||||
@@ -419,6 +426,8 @@ def submit_chain(args: argparse.Namespace) -> None:
|
||||
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
|
||||
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
|
||||
)
|
||||
if args.max_buffer_input_shards is not None:
|
||||
common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
|
||||
if args.local_root:
|
||||
common += f" --local_root {args.local_root}"
|
||||
env_prefix = "export TOKENIZERS_PARALLELISM=false"
|
||||
@@ -491,6 +500,13 @@ def parse_args() -> argparse.Namespace:
|
||||
p.add_argument(
|
||||
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
|
||||
)
|
||||
p.add_argument(
|
||||
"--max_buffer_input_shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Concurrently-live random episodes feeding the pool after reshard() "
|
||||
"(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
|
||||
)
|
||||
|
||||
+4
-3
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.7.0,<6.0.0",
|
||||
"datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -334,8 +334,9 @@ explicit = true
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
|
||||
# re-creation, fixed in datasets#8259 (merged, not yet released). Pin to the merge commit until the
|
||||
# next datasets release ships it, then drop this and bump the floor in `dependencies`.
|
||||
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
|
||||
# (#8194) — all merged, not yet in a tagged 5.0 release. Pin to the merge commit until the next
|
||||
# datasets release ships them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
|
||||
@@ -51,9 +51,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
The tabular side is a pure `datasets` pipeline::
|
||||
|
||||
load_dataset(streaming=True) # parquet shards from the Hub / a bucket
|
||||
-> reshard() # 1 shard == 1 row group == 1 episode
|
||||
-> split_dataset_by_node(rank, world_size) # disjoint shards per rank
|
||||
-> batch(by_column="episode_index") # whole episodes
|
||||
-> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob)
|
||||
-> batch(by_column="episode_index") # whole episodes (one per shard)
|
||||
-> shuffle(episode_pool_size, max_buffer_input_shards) # K random episodes, global perm
|
||||
-> map(explode + exact delta windows) # episode -> frames, windows are exact
|
||||
-> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave
|
||||
|
||||
@@ -62,6 +63,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
string. DataLoader workers are split natively by `datasets` (disjoint shards per worker),
|
||||
and resume uses the native ``state_dict`` / ``load_state_dict``.
|
||||
|
||||
Random-episode admission (Plan B): the LeRobot writer stores one Parquet row group per
|
||||
episode, so ``datasets.IterableDataset.reshard()`` makes one shard == one episode (no new
|
||||
files; shards are (file, row_group) pairs). ``shuffle`` then permutes shard order globally and
|
||||
fills its buffer from ``max_buffer_input_shards`` shards concurrently, so the episode pool is a
|
||||
uniformly-random sample of the corpus regardless of how many episodes are packed per file.
|
||||
``max_buffer_input_shards`` is the number of concurrently-live random episodes; set it
|
||||
``>= batch_size`` for the per-batch distinct-episode fraction to approach 1.
|
||||
|
||||
Requirement: ONE ROW GROUP PER EPISODE. Recorded datasets satisfy this; bulk
|
||||
``df.to_parquet`` / ``push_to_hub`` / aggregate paths collapse row groups and are rejected at
|
||||
init (see ``validate_row_groups``). Old collapsed datasets still load fine for the map-style
|
||||
path; only this streaming random-episode path requires the invariant.
|
||||
|
||||
Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are
|
||||
exact slices of the resident episode with correct padding at episode boundaries.
|
||||
|
||||
@@ -97,6 +111,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
episode_pool_size: int | None = 1024,
|
||||
max_buffer_input_shards: int | None = None,
|
||||
frame_shuffle_buffer_size: int | None = None,
|
||||
buffer_size: int | None = None,
|
||||
max_num_shards: int | None = None,
|
||||
@@ -108,6 +123,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
world_size: int | None = None,
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
validate_row_groups: bool = True,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -127,6 +143,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
across — the randomness knob. Larger mixes more episodes per batch (closer to
|
||||
map-style uniform) at the cost of cold-start latency and frame-buffer RAM.
|
||||
Defaults to 1024.
|
||||
max_buffer_input_shards (int | None, optional): Number of shards (== episodes, after
|
||||
``reshard()``) the episode-pool ``shuffle`` reads from concurrently — i.e. the count
|
||||
of concurrently-live random episodes feeding the pool from a global shard permutation.
|
||||
Set ``>= batch_size`` for the per-batch distinct-episode fraction to approach 1.
|
||||
Defaults to ``episode_pool_size``.
|
||||
frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the
|
||||
episode pool. Defaults to ``episode_pool_size x average episode length`` (capped),
|
||||
which matches the pool's mixing radius.
|
||||
@@ -147,6 +168,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
|
||||
trees (e.g. ``hf://buckets/<owner>/<name>``). When set, parquet and video bytes are read
|
||||
from there while metadata still loads from ``repo_id`` on the Hub.
|
||||
validate_row_groups (bool, optional): When True (default), verify at init that the dataset
|
||||
stores one Parquet row group per episode (sampling data-file footers) and that
|
||||
``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear
|
||||
``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging);
|
||||
the divisibility check then downgrades to a warning.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -175,6 +201,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
self.streaming = streaming
|
||||
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 1024
|
||||
self.max_buffer_input_shards = (
|
||||
max(1, max_buffer_input_shards) if max_buffer_input_shards else self.episode_pool_size
|
||||
)
|
||||
self.validate_row_groups = validate_row_groups
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
@@ -232,8 +262,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
if extra_columns:
|
||||
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
|
||||
|
||||
# Reshard Parquet per row group so 1 shard == 1 row group == 1 episode (the LeRobot writer
|
||||
# emits one row group per episode). This lets the episode-pool shuffle admit uniformly-random
|
||||
# episodes from a global shard permutation, independent of how many episodes are packed per file.
|
||||
if self.streaming:
|
||||
self.hf_dataset = self.hf_dataset.reshard()
|
||||
self.num_shards = self.hf_dataset.num_shards
|
||||
|
||||
if self.validate_row_groups and self.streaming:
|
||||
self._validate_row_groups_per_episode()
|
||||
|
||||
avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes)))
|
||||
self.frame_shuffle_buffer_size = (
|
||||
frame_shuffle_buffer_size
|
||||
@@ -283,22 +321,102 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _resolve_data_root(self) -> str:
|
||||
"""fsspec root that holds the bulk ``data/`` parquet tree (revision-qualified for the Hub)."""
|
||||
if self.data_files_root is not None:
|
||||
return self.data_files_root
|
||||
if self.streaming and not self.streaming_from_local:
|
||||
return f"hf://datasets/{self.repo_id}@{self.revision}"
|
||||
return str(self.root)
|
||||
|
||||
def _episode_files(self) -> dict[tuple[int, int], list[int]]:
|
||||
"""Map each data file ``(chunk_index, file_index)`` to the episode indices it stores."""
|
||||
file_to_eps: dict[tuple[int, int], list[int]] = {}
|
||||
for ep in range(self.meta.total_episodes):
|
||||
row = self.meta.episodes[ep]
|
||||
key = (int(row["data/chunk_index"]), int(row["data/file_index"]))
|
||||
file_to_eps.setdefault(key, []).append(ep)
|
||||
return file_to_eps
|
||||
|
||||
def _validate_row_groups_per_episode(self, sample_files: int = 32) -> None:
    """Check that data files keep ONE ROW GROUP PER EPISODE, as ``reshard()`` needs.

    Two checks run, in order:
      1. A whole-dataset check: the shard count did not exceed the number of data files while
         the dataset holds more episodes than shards.
      2. A sampled check: for up to ``sample_files`` data files (picked with a generator seeded
         from ``self.seed``), the Parquet metadata is opened and the file's row-group count is
         compared with the number of episodes the metadata assigns to that file.

    Args:
        sample_files: Maximum number of data files whose Parquet metadata is inspected.

    Raises:
        ValueError: if the whole-dataset check trips, or a sampled file has fewer row groups
            than episodes.
    """
    import fsspec
    import pyarrow.parquet as pq

    episodes_by_file = self._episode_files()
    file_count = len(episodes_by_file)

    # Dataset-wide extreme: resharding produced no more shards than files, yet episodes outnumber
    # shards, so several episodes must share a shard.
    still_file_granular = self.num_shards <= file_count and self.meta.total_episodes > self.num_shards
    if still_file_granular:
        raise ValueError(
            f"{self.repo_id}: after reshard() the stream still has only {self.num_shards} shard(s) "
            f"for {self.meta.total_episodes} episodes across {file_count} data file(s) — i.e. one "
            "row group per file. StreamingLeRobotDataset random-episode shuffling requires ONE ROW "
            "GROUP PER EPISODE so each episode is an independently addressable shard after reshard(). "
            "Re-emit through the LeRobot writer (one write_table per episode) or fix the aggregate / "
            "annotate / push_to_hub writer that collapsed the row groups, then re-upload. Recorded "
            "datasets already satisfy this. Pass validate_row_groups=False to bypass (random-episode "
            "quality will degrade)."
        )

    root_uri = self._resolve_data_root()
    sampler = np.random.default_rng(self.seed)
    locations = list(episodes_by_file)
    sample_size = min(sample_files, len(locations))
    for pick in sampler.choice(len(locations), size=sample_size, replace=False):
        location = locations[int(pick)]
        chunk_idx, file_idx = location
        expected_episodes = len(episodes_by_file[location])
        relative_path = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
        path = f"{root_uri}/{relative_path}"

        # Only Parquet metadata fields are read from the opened file.
        with fsspec.open(path, "rb") as handle:
            parquet_file = pq.ParquetFile(handle)
            row_group_count = parquet_file.num_row_groups
            row_count = parquet_file.metadata.num_rows

        if row_group_count >= expected_episodes:
            continue
        raise ValueError(
            f"{path}: stored as {row_group_count} Parquet row group(s) ({row_count} rows across "
            f"{expected_episodes} episodes). StreamingLeRobotDataset random-episode shuffling requires ONE ROW "
            "GROUP PER EPISODE so each episode becomes an independently addressable shard after "
            "reshard(). This file was written by a bulk df.to_parquet / push_to_hub / aggregate "
            "path that collapses row groups. Re-emit through the LeRobot writer (one write_table "
            "per episode) or fix the aggregate/annotate writer, then re-upload. Recorded datasets "
            "already satisfy this. Pass validate_row_groups=False to bypass (quality will degrade)."
        )
|
||||
|
||||
def _build_pipeline(self) -> datasets.IterableDataset:
|
||||
"""Assemble the native tabular pipeline (everything except video decode)."""
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
if ds.num_shards % self.world_size != 0:
|
||||
logger.warning(
|
||||
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): "
|
||||
"datasets falls back to example-level splitting where every rank reads (and pays "
|
||||
"for) the full stream. Re-shard the dataset or adjust world size."
|
||||
msg = (
|
||||
f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}). "
|
||||
"After reshard() num_shards == the episode count, and split_dataset_by_node only "
|
||||
"assigns shards evenly when num_shards % world_size == 0; otherwise every rank "
|
||||
"streams (and pays for) the full dataset and keeps only 1/world_size of it. Pin "
|
||||
"world_size to a divisor of the episode count, or drop/pad episodes to a divisible "
|
||||
"count with the dataset tools. Set validate_row_groups=False to downgrade to a warning."
|
||||
)
|
||||
if self.validate_row_groups:
|
||||
raise ValueError(msg)
|
||||
logger.warning(msg)
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
ds = ds.batch(by_column="episode_index")
|
||||
episode_columns = list(ds.column_names or self.hf_dataset.column_names or [])
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size)
|
||||
max_input_shards = max(1, min(self.max_buffer_input_shards, ds.num_shards))
|
||||
ds = ds.shuffle(
|
||||
seed=self.seed,
|
||||
buffer_size=self.episode_pool_size,
|
||||
max_buffer_input_shards=max_input_shards,
|
||||
)
|
||||
# A row-count-changing batched map must drop the input columns explicitly; the exploded
|
||||
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
|
||||
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
|
||||
@@ -358,10 +476,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
|
||||
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
|
||||
# epoch is tracked separately so a mid-iteration state_dict() records the epoch the
|
||||
# stream position actually belongs to.
|
||||
self._in_flight_epoch = self._epoch
|
||||
# stream position actually belongs to. Only advance when shuffling: after reshard() the
|
||||
# stream has one shard per episode, and set_epoch(n>0) re-permutes shard order even without
|
||||
# a shuffle op, so an unshuffled stream must pin epoch 0 to repeat the same order each pass.
|
||||
if self.shuffle:
|
||||
self._in_flight_epoch = self._epoch
|
||||
self._epoch += 1
|
||||
else:
|
||||
self._in_flight_epoch = 0
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self._epoch += 1
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
|
||||
iterator = iter(self._pipeline)
|
||||
|
||||
@@ -312,3 +312,119 @@ def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
|
||||
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
|
||||
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
|
||||
assert state is not None
|
||||
|
||||
|
||||
# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
|
||||
|
||||
|
||||
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
    """reshard() must yield one shard per episode when each episode is its own row group.

    All episodes are expected to land in a single data file, so ``num_shards`` can only equal
    ``total_episodes`` if the stream was split per row group rather than per file.
    """
    import pyarrow.parquet as pq

    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-reshard"
    total_episodes = 3
    factory_kwargs = {
        "root": dataset_root,
        "repo_id": repo_id,
        "total_episodes": total_episodes,
        "total_frames": 90,
        "use_videos": False,
    }
    # The default (large) data-file size keeps every episode in the same parquet file.
    lerobot_dataset_factory(**factory_kwargs)

    ds = StreamingLeRobotDataset(repo_id=repo_id, root=dataset_root, shuffle=False, episode_pool_size=3)

    episodes_by_file = ds._episode_files()
    assert len(episodes_by_file) == 1, "test expects all episodes packed into a single data file"

    # The fixture itself must honour the invariant: as many row groups as episodes per file.
    for (chunk_idx, file_idx), episode_ids in episodes_by_file.items():
        relative_path = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
        parquet_file = pq.ParquetFile(str(ds.root / relative_path))
        assert parquet_file.num_row_groups == len(episode_ids)

    assert ds.num_shards == total_episodes
|
||||
|
||||
|
||||
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
    """The first batch of a shuffled stream should already draw from most live episodes.

    ``max_buffer_input_shards`` is set to the episode count, so the head of the stream is expected
    to cover nearly every episode (a slack of two is tolerated).
    """
    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-frac"
    total_episodes = 8
    lerobot_dataset_factory(
        root=dataset_root,
        repo_id=repo_id,
        total_episodes=total_episodes,
        total_frames=240,
        use_videos=False,
    )

    stream_kwargs = {
        "repo_id": repo_id,
        "root": dataset_root,
        "shuffle": True,
        "seed": 0,
        "episode_pool_size": total_episodes,
        "max_buffer_input_shards": total_episodes,
    }
    ds = StreamingLeRobotDataset(**stream_kwargs)
    assert ds.max_buffer_input_shards == total_episodes

    batch = 32
    # Collect the episode ids of the first `batch` frames without draining the stream further.
    head = set()
    for _, frame in zip(range(batch), ds, strict=False):
        head.add(int(frame["episode_index"]))

    required = min(total_episodes, batch) - 2
    assert len(head) >= required, f"batch did not mix random episodes: {head}"
|
||||
|
||||
|
||||
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
    """A data file holding several episodes in one row group must be rejected at init.

    The fixture is rewritten with a plain ``pq.write_table`` to simulate a bulk writer that
    collapses row groups; construction is then expected to fail with the actionable error.
    """
    import pyarrow.parquet as pq

    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-collapsed"
    lerobot_dataset_factory(
        root=dataset_root, repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
    )

    # Round-trip every data file through read_table/write_table so it ends up as one row group.
    for parquet_path in (dataset_root / "data").rglob("*.parquet"):
        whole_table = pq.read_table(parquet_path)
        pq.write_table(whole_table, parquet_path)

    with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
        StreamingLeRobotDataset(repo_id=repo_id, root=dataset_root, shuffle=False, episode_pool_size=3)
|
||||
|
||||
|
||||
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
    """With ``validate_row_groups=False`` a collapsed dataset still loads and yields every frame."""
    import pyarrow.parquet as pq

    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
    lerobot_dataset_factory(
        root=dataset_root, repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
    )

    # Collapse each data file into a single row group, as a bulk writer would.
    for parquet_path in (dataset_root / "data").rglob("*.parquet"):
        whole_table = pq.read_table(parquet_path)
        pq.write_table(whole_table, parquet_path)

    ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=dataset_root,
        shuffle=False,
        episode_pool_size=3,
        validate_row_groups=False,
    )

    frame_indices = sorted(int(frame["index"]) for frame in ds)
    assert frame_indices == list(range(90))
|
||||
|
||||
|
||||
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
    """A shard count not divisible by ``world_size`` must raise, unless validation is disabled.

    Three episodes split over two ranks is the failing configuration; with
    ``validate_row_groups=False`` the same configuration must construct without raising.
    """
    dataset_root = tmp_path / "ds"
    repo_id = f"{DUMMY_REPO_ID}-divis"
    lerobot_dataset_factory(
        root=dataset_root, repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
    )

    two_rank_kwargs = {
        "repo_id": repo_id,
        "root": dataset_root,
        "shuffle": False,
        "episode_pool_size": 3,
        "rank": 0,
        "world_size": 2,
    }

    with pytest.raises(ValueError, match="not divisible by world_size"):
        StreamingLeRobotDataset(**two_rank_kwargs)

    # Disabling validation turns the guard into a warning, so construction succeeds.
    ds = StreamingLeRobotDataset(**two_rank_kwargs, validate_row_groups=False)
    assert ds.num_shards == 3
|
||||
|
||||
Vendored
+22
-3
@@ -17,6 +17,7 @@ from pathlib import Path
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -35,6 +36,24 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
    """Write ``hf_dataset`` to ``path``, issuing one ``write_table`` call per episode.

    Rows are cut wherever consecutive ``episode_index`` values differ, and each resulting slice is
    written separately so that every episode ends up in its own Parquet row group. This mirrors
    the LeRobot recording writer and keeps each episode an independently addressable shard after
    ``datasets.IterableDataset.reshard()``, which ``StreamingLeRobotDataset`` relies on.
    ``Dataset.to_parquet`` would collapse the file into a single row group instead.

    Args:
        hf_dataset: Dataset whose rows are ordered so that each episode is contiguous.
        path: Destination parquet file.
    """
    arrow_table = hf_dataset.with_format("arrow")[:]
    episode_ids = np.asarray(hf_dataset["episode_index"])

    # Row offsets at which a new episode begins, bracketed by the table start and end.
    change_points = (np.nonzero(np.diff(episode_ids) != 0)[0] + 1).tolist()
    offsets = [0, *change_points, len(episode_ids)]

    with pq.ParquetWriter(str(path), arrow_table.schema) as writer:
        for first_row, stop_row in zip(offsets[:-1], offsets[1:], strict=True):
            episode_rows = arrow_table.slice(first_row, stop_row - first_row)
            writer.write_table(episode_rows)
|
||||
|
||||
|
||||
def write_hf_dataset(
|
||||
hf_dataset: Dataset,
|
||||
local_dir: Path,
|
||||
@@ -67,7 +86,7 @@ def write_hf_dataset(
|
||||
# If the dataset is small enough, write it to a single file.
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset.to_parquet(path)
|
||||
_to_parquet_one_row_group_per_episode(hf_dataset, path)
|
||||
return
|
||||
|
||||
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
|
||||
@@ -114,8 +133,8 @@ def write_hf_dataset(
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the shard to a Parquet file.
|
||||
dataset_shard.to_parquet(path)
|
||||
# Write the shard to a Parquet file (one row group per episode).
|
||||
_to_parquet_one_row_group_per_episode(dataset_shard, path)
|
||||
|
||||
# Update chunk and file indices for the next iteration.
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
|
||||
Reference in New Issue
Block a user