From 49402811200ba7afc28077428905102f4718b205 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 15 Jun 2026 13:33:27 +0000 Subject: [PATCH] feat(streaming): random-episode admission via reshard() + multi-input-shard shuffle Reshard parquet per row group (1 shard == 1 row group == 1 episode) and feed the episode-pool shuffle with max_buffer_input_shards so the pool is a uniform random sample of the corpus, independent of episodes-per-file. Add validate_row_groups guardrails (collapsed-row-group + distributed divisibility), require datasets>=5.0.0, make the test fixture write one row group per episode, and plumb max_buffer_input_shards through the dataloading benchmark. Co-authored-by: Cursor --- examples/scaling/benchmark_dataloading.py | 16 +++ pyproject.toml | 7 +- src/lerobot/datasets/streaming_dataset.py | 143 ++++++++++++++++++++-- tests/datasets/test_streaming_native.py | 116 ++++++++++++++++++ tests/fixtures/files.py | 25 +++- 5 files changed, 291 insertions(+), 16 deletions(-) diff --git a/examples/scaling/benchmark_dataloading.py b/examples/scaling/benchmark_dataloading.py index 4da1f2697..40e4910ee 100644 --- a/examples/scaling/benchmark_dataloading.py +++ b/examples/scaling/benchmark_dataloading.py @@ -222,8 +222,12 @@ def build_dataset(scenario: str, args: argparse.Namespace): revision=MOLMO_REVISION, data_files_root=data_files_root, episode_pool_size=args.episode_pool_size, + max_buffer_input_shards=args.max_buffer_input_shards, video_decoder_cache_size=args.video_decoder_cache_size, tolerance_s=1e-3, + # Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public + # dataset may be collapsed); reshard() still yields per-episode shards where it holds. + validate_row_groups=False, ) return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root} @@ -362,6 +366,9 @@ def run_scenario(scenario: str, args: argparse.Namespace) -> None: "batch_size": args.batch_size, "num_workers": args.num_workers, "episode_pool_size": None if is_map_style else args.episode_pool_size, + "max_buffer_input_shards": None + if is_map_style + else (args.max_buffer_input_shards or args.episode_pool_size), **info, "num_cameras": num_cameras, "image_shape": image_shape, @@ -419,6 +426,8 @@ def submit_chain(args: argparse.Namespace) -> None: f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} " f"--num_batches {args.num_batches} --out_dir {args.out_dir}" ) + if args.max_buffer_input_shards is not None: + common += f" --max_buffer_input_shards {args.max_buffer_input_shards}" if args.local_root: common += f" --local_root {args.local_root}" env_prefix = "export TOKENIZERS_PARALLELISM=false" @@ -491,6 +500,13 @@ def parse_args() -> argparse.Namespace: p.add_argument( "--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)." ) + p.add_argument( + "--max_buffer_input_shards", + type=int, + default=None, + help="Concurrently-live random episodes feeding the pool after reshard() " + "(default: episode_pool_size). The frac knob; set >= batch_size for frac->1.", + ) p.add_argument( "--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)." ) diff --git a/pyproject.toml b/pyproject.toml index 10bdd0cf7..42116722a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ dependencies = [ # ── Feature-scoped extras ────────────────────────────────── dataset = [ - "datasets>=4.7.0,<6.0.0", + "datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...) "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets "lerobot[av-dep]", @@ -334,8 +334,9 @@ explicit = true torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] # Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle -# re-creation, fixed in datasets#8259 (merged, not yet released). Pin to the merge commit until the -# next datasets release ships it, then drop this and bump the floor in `dependencies`. +# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...) +# (#8194) — all merged, not yet in a tagged 5.0 release. Pin to the merge commit until the next +# datasets release ships them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`. datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" } [tool.setuptools.package-data] diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index e841df89b..fb6b9eef2 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -51,9 +51,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): The tabular side is a pure `datasets` pipeline:: load_dataset(streaming=True) # parquet shards from the Hub / a bucket + -> reshard() # 1 shard == 1 row group == 1 episode -> split_dataset_by_node(rank, world_size) # disjoint shards per rank - -> batch(by_column="episode_index") # whole episodes - -> shuffle(buffer_size=episode_pool_size) # episode pool (the randomness knob) + -> batch(by_column="episode_index") # whole episodes (one per shard) + -> shuffle(episode_pool_size, max_buffer_input_shards) # K random episodes, global perm -> map(explode + exact delta windows) # episode -> frames, windows are exact -> shuffle(buffer_size=frame_shuffle_buffer_size) # frame-level interleave @@ -62,6 +63,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): string. DataLoader workers are split natively by `datasets` (disjoint shards per worker), and resume uses the native ``state_dict`` / ``load_state_dict``. + Random-episode admission (Plan B): the LeRobot writer stores one Parquet row group per + episode, so ``datasets.IterableDataset.reshard()`` makes one shard == one episode (no new + files; shards are (file, row_group) pairs). ``shuffle`` then permutes shard order globally and + fills its buffer from ``max_buffer_input_shards`` shards concurrently, so the episode pool is a + uniformly-random sample of the corpus regardless of how many episodes are packed per file. + ``max_buffer_input_shards`` is the number of concurrently-live random episodes; set it + ``>= batch_size`` for the per-batch distinct-episode fraction to approach 1. + + Requirement: ONE ROW GROUP PER EPISODE. Recorded datasets satisfy this; bulk + ``df.to_parquet`` / ``push_to_hub`` / aggregate paths collapse row groups and are rejected at + init (see ``validate_row_groups``). Old collapsed datasets still load fine for the map-style + path; only this streaming random-episode path requires the invariant. + Randomness: a batch mixes up to ``episode_pool_size`` distinct episodes; delta windows are exact slices of the resident episode with correct padding at episode boundaries. @@ -97,6 +111,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): force_cache_sync: bool = False, streaming: bool = True, episode_pool_size: int | None = 1024, + max_buffer_input_shards: int | None = None, frame_shuffle_buffer_size: int | None = None, buffer_size: int | None = None, max_num_shards: int | None = None, @@ -108,6 +123,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): world_size: int | None = None, video_decoder_cache_size: int | None = None, data_files_root: str | None = None, + validate_row_groups: bool = True, ): """Initialize a StreamingLeRobotDataset. @@ -127,6 +143,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): across — the randomness knob. Larger mixes more episodes per batch (closer to map-style uniform) at the cost of cold-start latency and frame-buffer RAM. Defaults to 1024. + max_buffer_input_shards (int | None, optional): Number of shards (== episodes, after + ``reshard()``) the episode-pool ``shuffle`` reads from concurrently — i.e. the count + of concurrently-live random episodes feeding the pool from a global shard permutation. + Set ``>= batch_size`` for the per-batch distinct-episode fraction to approach 1. + Defaults to ``episode_pool_size``. frame_shuffle_buffer_size (int | None, optional): Frame-level shuffle buffer after the episode pool. Defaults to ``episode_pool_size x average episode length`` (capped), which matches the pool's mixing radius. @@ -147,6 +168,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/`` trees (e.g. ``hf://buckets//``). When set, parquet and video bytes are read from there while metadata still loads from ``repo_id`` on the Hub. + validate_row_groups (bool, optional): When True (default), verify at init that the dataset + stores one Parquet row group per episode (sampling data-file footers) and that + ``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear + ``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging); + the divisibility check then downgrades to a warning. """ super().__init__() self.repo_id = repo_id @@ -175,6 +201,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.streaming = streaming self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 1024 + self.max_buffer_input_shards = ( + max(1, max_buffer_input_shards) if max_buffer_input_shards else self.episode_pool_size + ) + self.validate_row_groups = validate_row_groups self._return_uint8 = return_uint8 self.rank, self.world_size = self._resolve_distributed(rank, world_size) @@ -232,8 +262,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): if extra_columns: self.hf_dataset = self.hf_dataset.remove_columns(extra_columns) + # Reshard Parquet per row group so 1 shard == 1 row group == 1 episode (the LeRobot writer + # emits one row group per episode). This lets the episode-pool shuffle admit uniformly-random + # episodes from a global shard permutation, independent of how many episodes are packed per file. + if self.streaming: + self.hf_dataset = self.hf_dataset.reshard() self.num_shards = self.hf_dataset.num_shards + if self.validate_row_groups and self.streaming: + self._validate_row_groups_per_episode() + avg_episode_len = max(1, round(self.meta.total_frames / max(1, self.meta.total_episodes))) self.frame_shuffle_buffer_size = ( frame_shuffle_buffer_size @@ -283,22 +321,102 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return 0, 1 + def _resolve_data_root(self) -> str: + """fsspec root that holds the bulk ``data/`` parquet tree (revision-qualified for the Hub).""" + if self.data_files_root is not None: + return self.data_files_root + if self.streaming and not self.streaming_from_local: + return f"hf://datasets/{self.repo_id}@{self.revision}" + return str(self.root) + + def _episode_files(self) -> dict[tuple[int, int], list[int]]: + """Map each data file ``(chunk_index, file_index)`` to the episode indices it stores.""" + file_to_eps: dict[tuple[int, int], list[int]] = {} + for ep in range(self.meta.total_episodes): + row = self.meta.episodes[ep] + key = (int(row["data/chunk_index"]), int(row["data/file_index"])) + file_to_eps.setdefault(key, []).append(ep) + return file_to_eps + + def _validate_row_groups_per_episode(self, sample_files: int = 32) -> None: + """Verify the dataset stores ONE ROW GROUP PER EPISODE so each episode is an independently + addressable shard after ``reshard()``. Cheap (footer-only) and sampled. + + Raises: + ValueError: if a sampled data file collapses several episodes into fewer row groups, or + the whole dataset is one row group per file while holding many more episodes than files. + """ + import fsspec + import pyarrow.parquet as pq + + file_to_eps = self._episode_files() + num_data_files = len(file_to_eps) + + # Whole-dataset extreme: reshard() could not split beyond file granularity (one row group per + # file) yet there are many more episodes than files -> collapsed. + if self.num_shards <= num_data_files and self.meta.total_episodes > self.num_shards: + raise ValueError( + f"{self.repo_id}: after reshard() the stream still has only {self.num_shards} shard(s) " + f"for {self.meta.total_episodes} episodes across {num_data_files} data file(s) — i.e. one " + "row group per file. StreamingLeRobotDataset random-episode shuffling requires ONE ROW " + "GROUP PER EPISODE so each episode is an independently addressable shard after reshard(). " + "Re-emit through the LeRobot writer (one write_table per episode) or fix the aggregate / " + "annotate / push_to_hub writer that collapsed the row groups, then re-upload. Recorded " + "datasets already satisfy this. Pass validate_row_groups=False to bypass (random-episode " + "quality will degrade)." + ) + + data_root = self._resolve_data_root() + rng = np.random.default_rng(self.seed) + keys = list(file_to_eps) + chosen = rng.choice(len(keys), size=min(sample_files, len(keys)), replace=False) + for i in chosen: + chunk_idx, file_idx = keys[int(i)] + n_ep = len(file_to_eps[(chunk_idx, file_idx)]) + rel = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + path = f"{data_root}/{rel}" + with fsspec.open(path, "rb") as f: + pf = pq.ParquetFile(f) + n_rg = pf.num_row_groups + num_rows = pf.metadata.num_rows + if n_rg < n_ep: + raise ValueError( + f"{path}: stored as {n_rg} Parquet row group(s) ({num_rows} rows across " + f"{n_ep} episodes). StreamingLeRobotDataset random-episode shuffling requires ONE ROW " + "GROUP PER EPISODE so each episode becomes an independently addressable shard after " + "reshard(). This file was written by a bulk df.to_parquet / push_to_hub / aggregate " + "path that collapses row groups. Re-emit through the LeRobot writer (one write_table " + "per episode) or fix the aggregate/annotate writer, then re-upload. Recorded datasets " + "already satisfy this. Pass validate_row_groups=False to bypass (quality will degrade)." + ) + def _build_pipeline(self) -> datasets.IterableDataset: """Assemble the native tabular pipeline (everything except video decode).""" ds = self.hf_dataset if self.world_size > 1: if ds.num_shards % self.world_size != 0: - logger.warning( - f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}): " - "datasets falls back to example-level splitting where every rank reads (and pays " - "for) the full stream. Re-shard the dataset or adjust world size." + msg = ( + f"num_shards ({ds.num_shards}) is not divisible by world_size ({self.world_size}). " + "After reshard() num_shards == the episode count, and split_dataset_by_node only " + "assigns shards evenly when num_shards % world_size == 0; otherwise every rank " + "streams (and pays for) the full dataset and keeps only 1/world_size of it. Pin " + "world_size to a divisor of the episode count, or drop/pad episodes to a divisible " + "count with the dataset tools. Set validate_row_groups=False to downgrade to a warning." ) + if self.validate_row_groups: + raise ValueError(msg) + logger.warning(msg) ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size) ds = ds.batch(by_column="episode_index") episode_columns = list(ds.column_names or self.hf_dataset.column_names or []) if self.shuffle: - ds = ds.shuffle(seed=self.seed, buffer_size=self.episode_pool_size) + max_input_shards = max(1, min(self.max_buffer_input_shards, ds.num_shards)) + ds = ds.shuffle( + seed=self.seed, + buffer_size=self.episode_pool_size, + max_buffer_input_shards=max_input_shards, + ) # A row-count-changing batched map must drop the input columns explicitly; the exploded # frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks). ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns) @@ -358,10 +476,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch); # DataLoader workers each advance their own copy's counter in lockstep. The in-flight # epoch is tracked separately so a mid-iteration state_dict() records the epoch the - # stream position actually belongs to. - self._in_flight_epoch = self._epoch + # stream position actually belongs to. Only advance when shuffling: after reshard() the + # stream has one shard per episode, and set_epoch(n>0) re-permutes shard order even without + # a shuffle op, so an unshuffled stream must pin epoch 0 to repeat the same order each pass. + if self.shuffle: + self._in_flight_epoch = self._epoch + self._epoch += 1 + else: + self._in_flight_epoch = 0 self._pipeline.set_epoch(self._in_flight_epoch) - self._epoch += 1 self.video_decoder_cache = self._make_video_decoder_cache() iterator = iter(self._pipeline) diff --git a/tests/datasets/test_streaming_native.py b/tests/datasets/test_streaming_native.py index de25eb144..91ded1f3c 100644 --- a/tests/datasets/test_streaming_native.py +++ b/tests/datasets/test_streaming_native.py @@ -312,3 +312,119 @@ def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory): assert isinstance(ds._pipeline, hf_datasets.IterableDataset) state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end assert state is not None + + +# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle --- + + +def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory): + """With one row group per episode (the writer's invariant), reshard() turns each episode into its + own shard, so num_shards == total_episodes even when many episodes share a single data file.""" + import pyarrow.parquet as pq + + repo_id = f"{DUMMY_REPO_ID}-reshard" + total_episodes = 3 + # Default (large) data-file size packs all (unequal-length) episodes into one file, so the only way + # num_shards can reach total_episodes is per-row-group resharding. + lerobot_dataset_factory( + root=tmp_path / "ds", + repo_id=repo_id, + total_episodes=total_episodes, + total_frames=90, + use_videos=False, + ) + ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3) + + file_to_eps = ds._episode_files() + assert len(file_to_eps) == 1, "test expects all episodes packed into a single data file" + for (chunk_idx, file_idx), eps in file_to_eps.items(): + rel = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + assert pq.ParquetFile(str(ds.root / rel)).num_row_groups == len(eps) + + assert ds.num_shards == total_episodes + + +def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory): + """max_buffer_input_shards (== concurrently-live random episodes) drives the per-batch episode mix: + a single batch should already span most of the live episodes.""" + repo_id = f"{DUMMY_REPO_ID}-frac" + total_episodes = 8 + lerobot_dataset_factory( + root=tmp_path / "ds", + repo_id=repo_id, + total_episodes=total_episodes, + total_frames=240, + use_videos=False, + ) + ds = StreamingLeRobotDataset( + repo_id=repo_id, + root=tmp_path / "ds", + shuffle=True, + seed=0, + episode_pool_size=total_episodes, + max_buffer_input_shards=total_episodes, + ) + assert ds.max_buffer_input_shards == total_episodes + + batch = 32 + head = {int(frame["episode_index"]) for _, frame in zip(range(batch), ds, strict=False)} + assert len(head) >= min(total_episodes, batch) - 2, f"batch did not mix random episodes: {head}" + + +def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory): + """A data file that collapses several episodes into a single row group (bulk df.to_parquet / + push_to_hub) must be rejected with an actionable error: reshard() cannot address its episodes.""" + import pyarrow.parquet as pq + + repo_id = f"{DUMMY_REPO_ID}-collapsed" + lerobot_dataset_factory( + root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False + ) + # Rewrite every data file as a single row group (simulating the aggregate/push_to_hub collapse). + for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"): + pq.write_table(pq.read_table(parquet_path), parquet_path) + + with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"): + StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3) + + +def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory): + """validate_row_groups=False skips the row-group check (collapsed datasets still load, degraded).""" + import pyarrow.parquet as pq + + repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass" + lerobot_dataset_factory( + root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False + ) + for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"): + pq.write_table(pq.read_table(parquet_path), parquet_path) + + ds = StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, validate_row_groups=False + ) + assert sorted(int(frame["index"]) for frame in ds) == list(range(90)) + + +def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory): + """When num_shards (== episodes after reshard) is not divisible by world_size, every rank would + stream the whole dataset; the guard must raise instead of silently degrading.""" + repo_id = f"{DUMMY_REPO_ID}-divis" + lerobot_dataset_factory( + root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False + ) + with pytest.raises(ValueError, match="not divisible by world_size"): + StreamingLeRobotDataset( + repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, rank=0, world_size=2 + ) + + # Bypassing the guard downgrades it to a warning (no raise). + ds = StreamingLeRobotDataset( + repo_id=repo_id, + root=tmp_path / "ds", + shuffle=False, + episode_pool_size=3, + rank=0, + world_size=2, + validate_row_groups=False, + ) + assert ds.num_shards == 3 diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 92d9ca1e2..3d5666947 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -17,6 +17,7 @@ from pathlib import Path import datasets import numpy as np import pandas as pd +import pyarrow.parquet as pq import pytest from datasets import Dataset @@ -35,6 +36,24 @@ from lerobot.datasets.utils import ( ) +def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None: + """Write ``hf_dataset`` to ``path`` with one Parquet row group per episode. + + Mirrors the LeRobot recording writer (one ``write_table`` per episode) so each episode stays an + independently addressable shard after ``datasets.IterableDataset.reshard()``, which + ``StreamingLeRobotDataset`` relies on. ``Dataset.to_parquet`` would collapse the file into a + single row group instead. + """ + table = hf_dataset.with_format("arrow")[:] + episode_index = np.asarray(hf_dataset["episode_index"]) + boundaries = np.where(np.diff(episode_index) != 0)[0] + 1 + starts = [0, *boundaries.tolist()] + ends = [*boundaries.tolist(), len(episode_index)] + with pq.ParquetWriter(str(path), table.schema) as writer: + for start, end in zip(starts, ends, strict=True): + writer.write_table(table.slice(start, end - start)) + + def write_hf_dataset( hf_dataset: Dataset, local_dir: Path, @@ -67,7 +86,7 @@ def write_hf_dataset( # If the dataset is small enough, write it to a single file. path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) path.parent.mkdir(parents=True, exist_ok=True) - hf_dataset.to_parquet(path) + _to_parquet_one_row_group_per_episode(hf_dataset, path) return # If the dataset is too large, split it into smaller chunks, keeping episodes whole. @@ -114,8 +133,8 @@ def write_hf_dataset( path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) path.parent.mkdir(parents=True, exist_ok=True) - # Write the shard to a Parquet file. - dataset_shard.to_parquet(path) + # Write the shard to a Parquet file (one row group per episode). + _to_parquet_one_row_group_per_episode(dataset_shard, path) # Update chunk and file indices for the next iteration. chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)