Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456)

* filter episodes in load_nested_dataset

* nit

* remove test filtering

* move import to module level

* added missing episode indices to the EpisodeAwareSampler in lerobot_train.py;
This commit is contained in:
Michel Aractingi
2025-11-18 17:26:41 +01:00
committed by GitHub
parent 784cdae55a
commit b464d9f8bc
3 changed files with 20 additions and 7 deletions
+1 -3
View File
@@ -830,7 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset: def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc.""" """hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features) features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features) hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch) hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset return hf_dataset
@@ -847,10 +847,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Determine requested episodes # Determine requested episodes
if self.episodes is None: if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes)) requested_episodes = set(range(self.meta.total_episodes))
else: else:
# Requesting specific episodes
requested_episodes = set(self.episodes) requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data # Check if all requested episodes are available in cached data
+18 -4
View File
@@ -28,6 +28,7 @@ import numpy as np
import packaging.version import packaging.version
import pandas import pandas
import pandas as pd import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
from datasets import Dataset from datasets import Dataset
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx return chunk_idx, file_idx
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset: def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format Concatenate all pyarrow references to return HF Dataset format
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
Args: Args:
pq_dir: Directory containing parquet files pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
""" """
paths = sorted(pq_dir.glob("*/*.parquet")) paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0: if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars(): with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features) # When no filtering needed, Dataset uses memory-mapped loading for efficiency
return datasets # PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
def get_parquet_num_frames(parquet_path: str | Path) -> int: def get_parquet_num_frames(parquet_path: str | Path) -> int:
+1
View File
@@ -274,6 +274,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
sampler = EpisodeAwareSampler( sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"], dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames, drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True, shuffle=True,
) )