diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index a1aeae0ce..9484a4342 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -863,17 +863,16 @@ class LeRobotDataset(torch.utils.data.Dataset): features = get_hf_features_from_features(self.features) - # This is the v2.1 logic that forces an efficient, pre-decoded cache build. - # This is the key to performance for dtype="image" datasets. - # 1. Check if specific episodes are requested by the user. + # This is the "data_files" path, which may be slow, but is + # necessary for visualization or evaluation on a subset. if self.episodes is not None: # Get the unique set of parquet files for the requested episodes fpaths = set() for ep_idx in self.episodes: # Need to read metadata to find the file path for this episode - # We use self.meta.episodes (the loaded dataset) here - ep_meta = self.meta.episodes[ep_idx] + # Use the pre-loaded metadata list + ep_meta = self.episodes_metadata_list[ep_idx] chunk_idx = ep_meta["data/chunk_index"] file_idx = ep_meta["data/file_index"] fpath_str = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) @@ -886,17 +885,15 @@ class LeRobotDataset(torch.utils.data.Dataset): ) # Filter the loaded dataset to *only* include the requested episodes - # This is necessary because the v3 files contain multiple episodes. + # This is necessary because v3 files can contain multiple episodes. requested_episodes_set = set(self.episodes) hf_dataset = hf_dataset.filter( - lambda x: x["episode_index"] in requested_episodes_set, - batched=True, # Use batched=True for faster filtering - batch_size=1000, + lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000 ) else: # THIS IS THE FAST PATH FOR TRAINING (self.episodes is None) - # Load all data files using data_dir, which is the most efficient. + # We must use `data_dir` to trigger the v2.1-style efficient cache. data_dir = str(self.root / "data") hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")