mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456)
* filter episodes in load_nested_dataset * nit * remove test filtering * move import to module level * added missing episode indices to the EpisodeAwareSampler in lerobot_train.py;
This commit is contained in:
@@ -830,7 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self.features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -847,10 +847,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Determine requested episodes
|
||||
if self.episodes is None:
|
||||
# Requesting all episodes - check if we have all episodes from metadata
|
||||
requested_episodes = set(range(self.meta.total_episodes))
|
||||
else:
|
||||
# Requesting specific episodes
|
||||
requested_episodes = set(self.episodes)
|
||||
|
||||
# Check if all requested episodes are available in cached data
|
||||
|
||||
@@ -28,6 +28,7 @@ import numpy as np
|
||||
import packaging.version
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
|
||||
return chunk_idx, file_idx
|
||||
|
||||
|
||||
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
|
||||
def load_nested_dataset(
|
||||
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
|
||||
) -> Dataset:
|
||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||
Concatenate all pyarrow references to return HF Dataset format
|
||||
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
|
||||
Args:
|
||||
pq_dir: Directory containing parquet files
|
||||
features: Optional features schema to ensure consistent loading of complex types like images
|
||||
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
|
||||
"""
|
||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
with SuppressProgressBars():
|
||||
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
return datasets
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -274,6 +274,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user