try fix 4

This commit is contained in:
Steven Palma
2025-11-05 21:26:52 +01:00
parent 00a4e6bfb3
commit 20333abc72
+7 -10
View File
@@ -863,17 +863,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
features = get_hf_features_from_features(self.features)
# This is the v2.1 logic that forces an efficient, pre-decoded cache build.
# This is the key to performance for dtype="image" datasets.
# 1. Check if specific episodes are requested by the user.
# This is the "data_files" path, which may be slow, but is
# necessary for visualization or evaluation on a subset.
if self.episodes is not None:
# Get the unique set of parquet files for the requested episodes
fpaths = set()
for ep_idx in self.episodes:
# Need to read metadata to find the file path for this episode
# We use self.meta.episodes (the loaded dataset) here
ep_meta = self.meta.episodes[ep_idx]
# Use the pre-loaded metadata list
ep_meta = self.episodes_metadata_list[ep_idx]
chunk_idx = ep_meta["data/chunk_index"]
file_idx = ep_meta["data/file_index"]
fpath_str = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
@@ -886,17 +885,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
# Filter the loaded dataset to *only* include the requested episodes
# This is necessary because the v3 files contain multiple episodes.
# This is necessary because v3 files can contain multiple episodes.
requested_episodes_set = set(self.episodes)
hf_dataset = hf_dataset.filter(
lambda x: x["episode_index"] in requested_episodes_set,
batched=True, # Use batched=True for faster filtering
batch_size=1000,
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
)
else:
# This is the fast path for training (self.episodes is None).
# Load all data files using data_dir, which is the most efficient approach.
# We must use `data_dir` to trigger the v2.1-style efficient cache.
data_dir = str(self.root / "data")
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")