diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 547318519..f348b40eb 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -259,15 +259,6 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: return cast_stats_to_numpy(stats) -def write_hf_dataset(hf_dataset: Dataset, local_dir: Path): - if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_DATA_FILE_SIZE_IN_MB: - raise NotImplementedError("Contact a maintainer.") - - path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) - path.parent.mkdir(parents=True, exist_ok=True) - hf_dataset.to_parquet(path) - - def write_tasks(tasks: pandas.DataFrame, local_dir: Path): path = local_dir / DEFAULT_TASKS_PATH path.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 6ea9a8105..783058014 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -488,6 +488,8 @@ def lerobot_dataset_factory( tasks: pd.DataFrame | None = None, episodes_metadata: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None, + data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, + chunks_size: int = DEFAULT_CHUNK_SIZE, **kwargs, ) -> LeRobotDataset: # Instantiate objects @@ -497,6 +499,8 @@ def lerobot_dataset_factory( total_frames=total_frames, total_tasks=total_tasks, use_videos=use_videos, + data_files_size_in_mb=data_files_size_in_mb, + chunks_size=chunks_size, ) if stats is None: stats = stats_factory(features=info["features"]) @@ -525,6 +529,8 @@ def lerobot_dataset_factory( tasks=tasks, episodes=episodes_metadata, hf_dataset=hf_dataset, + data_files_size_in_mb=data_files_size_in_mb, + chunks_size=chunks_size, ) mock_metadata = lerobot_dataset_metadata_factory( root=root, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 6441f8303..2f2980046 100644 --- a/tests/fixtures/files.py +++ 
def write_hf_dataset(
    hf_dataset: Dataset,
    local_dir: Path,
    data_file_size_mb: float | None = None,
    chunk_size: int | None = None,
):
    """Write a Hugging Face dataset to one or more Parquet files under `local_dir`.

    If the whole dataset is no larger than `data_file_size_mb`, it is written to a
    single file at chunk index 0 / file index 0. Otherwise consecutive episodes are
    greedily packed into shards: an episode is appended to the current shard as long
    as the resulting shard stays within `data_file_size_mb`. Episodes are never split
    across files, so a single episode that is larger than the limit is still written
    to its own shard (a warning is logged in that case).

    Every shard is written to
    `local_dir / DEFAULT_DATA_PATH.format(chunk_index=..., file_index=...)`, and the
    indices are advanced with `update_chunk_file_indices` after each shard.

    Episode boundaries are detected from changes between consecutive values of the
    `episode_index` column, so rows of one episode are expected to be contiguous.

    Args:
        hf_dataset: The dataset to write to disk. Must contain an `episode_index`
            column when it is larger than `data_file_size_mb`.
        local_dir: Root directory under which the Parquet files are created.
        data_file_size_mb: Target maximal size of a Parquet data file, in MB.
            Defaults to `DEFAULT_DATA_FILE_SIZE_IN_MB` when None.
        chunk_size: Value forwarded to `update_chunk_file_indices` to decide when to
            move on to the next chunk folder. Defaults to `DEFAULT_CHUNK_SIZE` when None.
    """
    if data_file_size_mb is None:
        data_file_size_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    if get_hf_dataset_size_in_mb(hf_dataset) <= data_file_size_mb:
        # Small enough: everything goes into a single file.
        path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
        path.parent.mkdir(parents=True, exist_ok=True)
        hf_dataset.to_parquet(path)
        return

    # Row spans [start, end) of each run of identical consecutive episode indices.
    episode_indices = np.array(hf_dataset["episode_index"])
    episode_boundaries = np.where(np.diff(episode_indices) != 0)[0] + 1
    episode_starts = np.concatenate(([0], episode_boundaries))
    episode_ends = np.concatenate((episode_boundaries, [len(hf_dataset)]))
    num_episodes = len(episode_starts)

    chunk_idx, file_idx = 0, 0
    current_episode_idx = 0
    while current_episode_idx < num_episodes:
        # Plain ints keep numpy scalars out of `range` and out of the log message.
        shard_start_row = int(episode_starts[current_episode_idx])
        shard_end_row = int(episode_ends[current_episode_idx])
        next_episode_idx = current_episode_idx + 1

        # Greedily extend the shard one whole episode at a time while it still fits.
        while next_episode_idx < num_episodes:
            candidate_end_row = int(episode_ends[next_episode_idx])
            candidate_shard = hf_dataset.select(range(shard_start_row, candidate_end_row))
            if get_hf_dataset_size_in_mb(candidate_shard) > data_file_size_mb:
                break
            shard_end_row = candidate_end_row
            next_episode_idx += 1

        dataset_shard = hf_dataset.select(range(shard_start_row, shard_end_row))

        # A shard holding exactly one episode was never size-checked by the loop above;
        # it may exceed the limit on its own, which is tolerated but reported.
        if next_episode_idx == current_episode_idx + 1:
            shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard)
            if shard_size_mb > data_file_size_mb:
                logging.warning(
                    "Episode with index %s has size %.2fMB, "
                    "which is larger than data_file_size_mb (%sMB). "
                    "Writing it to a separate shard anyway to preserve episode integrity.",
                    # Read from the already-materialized column instead of indexing the dataset again.
                    episode_indices[shard_start_row],
                    shard_size_mb,
                    data_file_size_mb,
                )

        path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        path.parent.mkdir(parents=True, exist_ok=True)
        dataset_shard.to_parquet(path)

        # Move to the next file slot and continue after the last episode packed into this shard.
        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
        current_episode_idx = next_episode_idx
- info = info_factory() + info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size) if stats is None: stats = stats_factory(features=info["features"]) if tasks is None: @@ -132,7 +136,7 @@ def mock_snapshot_download_factory( if request_episodes: create_episodes(local_dir, episodes) if request_data: - create_hf_dataset(local_dir, hf_dataset) + create_hf_dataset(local_dir, hf_dataset, data_files_size_in_mb, chunks_size) if request_videos: create_videos(root=local_dir, info=info)