mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
Fix episode filtering bug when requesting a subset of the episodes in a dataset (#2456)
* filter episodes in load_nested_dataset * nit * remove test filtering * move import to module level * added missing episode indices to the EpisodeAwareSampler in lerobot_train.py;
This commit is contained in:
@@ -830,7 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def load_hf_dataset(self) -> datasets.Dataset:
|
def load_hf_dataset(self) -> datasets.Dataset:
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||||
features = get_hf_features_from_features(self.features)
|
features = get_hf_features_from_features(self.features)
|
||||||
hf_dataset = load_nested_dataset(self.root / "data", features=features)
|
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
@@ -847,10 +847,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# Determine requested episodes
|
# Determine requested episodes
|
||||||
if self.episodes is None:
|
if self.episodes is None:
|
||||||
# Requesting all episodes - check if we have all episodes from metadata
|
|
||||||
requested_episodes = set(range(self.meta.total_episodes))
|
requested_episodes = set(range(self.meta.total_episodes))
|
||||||
else:
|
else:
|
||||||
# Requesting specific episodes
|
|
||||||
requested_episodes = set(self.episodes)
|
requested_episodes = set(self.episodes)
|
||||||
|
|
||||||
# Check if all requested episodes are available in cached data
|
# Check if all requested episodes are available in cached data
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ import numpy as np
|
|||||||
import packaging.version
|
import packaging.version
|
||||||
import pandas
|
import pandas
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pyarrow.dataset as pa_ds
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
|
|||||||
return chunk_idx, file_idx
|
return chunk_idx, file_idx
|
||||||
|
|
||||||
|
|
||||||
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
|
def load_nested_dataset(
|
||||||
|
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
|
||||||
|
) -> Dataset:
|
||||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||||
Concatenate all pyarrow references to return HF Dataset format
|
Concatenate all pyarrow references to return HF Dataset format
|
||||||
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
|
|||||||
Args:
|
Args:
|
||||||
pq_dir: Directory containing parquet files
|
pq_dir: Directory containing parquet files
|
||||||
features: Optional features schema to ensure consistent loading of complex types like images
|
features: Optional features schema to ensure consistent loading of complex types like images
|
||||||
|
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
|
||||||
"""
|
"""
|
||||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||||
if len(paths) == 0:
|
if len(paths) == 0:
|
||||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||||
|
|
||||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
|
||||||
with SuppressProgressBars():
|
with SuppressProgressBars():
|
||||||
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
|
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||||
return datasets
|
# PyArrow loads the entire dataset into memory
|
||||||
|
if episodes is None:
|
||||||
|
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||||
|
|
||||||
|
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||||
|
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||||
|
table = arrow_dataset.to_table(filter=filter_expr)
|
||||||
|
|
||||||
|
if features is not None:
|
||||||
|
table = table.cast(features.arrow_schema)
|
||||||
|
|
||||||
|
return Dataset(table)
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||||
|
|||||||
@@ -274,6 +274,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
sampler = EpisodeAwareSampler(
|
sampler = EpisodeAwareSampler(
|
||||||
dataset.meta.episodes["dataset_from_index"],
|
dataset.meta.episodes["dataset_from_index"],
|
||||||
dataset.meta.episodes["dataset_to_index"],
|
dataset.meta.episodes["dataset_to_index"],
|
||||||
|
episode_indices_to_use=dataset.episodes,
|
||||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user