From b464d9f8bc18be4288c98b1e1aff907b1f58c92b Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 18 Nov 2025 17:26:41 +0100 Subject: [PATCH] Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456) * filter episodes in load_nested_dataset * nit * remove test filtering * move import to module level * added missing episode indices to the EpisodeAwareSampler in lerobot_train.py; --- src/lerobot/datasets/lerobot_dataset.py | 4 +--- src/lerobot/datasets/utils.py | 22 ++++++++++++++++++---- src/lerobot/scripts/lerobot_train.py | 1 + 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 48608a809..29436c4d2 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -830,7 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" features = get_hf_features_from_features(self.features) - hf_dataset = load_nested_dataset(self.root / "data", features=features) + hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset @@ -847,10 +847,8 @@ class LeRobotDataset(torch.utils.data.Dataset): # Determine requested episodes if self.episodes is None: - # Requesting all episodes - check if we have all episodes from metadata requested_episodes = set(range(self.meta.total_episodes)) else: - # Requesting specific episodes requested_episodes = set(self.episodes) # Check if all requested episodes are available in cached data diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 37d8432b2..ce4cf1da1 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -28,6 +28,7 @@ import numpy as np import packaging.version import pandas import pandas as pd +import pyarrow.dataset as pa_ds import pyarrow.parquet as pq import torch from datasets import Dataset @@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) - return chunk_idx, file_idx -def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset: +def load_nested_dataset( + pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None +) -> Dataset: """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage Concatenate all pyarrow references to return HF Dataset format @@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) Args: pq_dir: Directory containing parquet files features: Optional features schema to ensure consistent loading of complex types like images + episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency. """ paths = sorted(pq_dir.glob("*/*.parquet")) if len(paths) == 0: raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") - # TODO(rcadene): set num_proc to accelerate conversion to pyarrow with SuppressProgressBars(): - datasets = Dataset.from_parquet([str(path) for path in paths], features=features) - return datasets + # When no filtering needed, Dataset uses memory-mapped loading for efficiency + # PyArrow loads the entire dataset into memory + if episodes is None: + return Dataset.from_parquet([str(path) for path in paths], features=features) + + arrow_dataset = pa_ds.dataset(paths, format="parquet") + filter_expr = pa_ds.field("episode_index").isin(episodes) + table = arrow_dataset.to_table(filter=filter_expr) + + if features is not None: + table = table.cast(features.arrow_schema) + + return Dataset(table) def get_parquet_num_frames(parquet_path: str | Path) -> int: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 0cc6e037f..e6cf40b9b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -274,6 +274,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], + episode_indices_to_use=dataset.episodes, drop_n_last_frames=cfg.policy.drop_n_last_frames, shuffle=True, )