mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 07:37:10 +00:00
try fix 4
This commit is contained in:
@@ -863,17 +863,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
features = get_hf_features_from_features(self.features)
|
||||
|
||||
# This is the v2.1 logic that forces an efficient, pre-decoded cache build.
|
||||
# This is the key to performance for dtype="image" datasets.
|
||||
|
||||
# 1. Check if specific episodes are requested by the user.
|
||||
# This is the "data_files" path, which may be slow, but is
|
||||
# necessary for visualization or evaluation on a subset.
|
||||
if self.episodes is not None:
|
||||
# Get the unique set of parquet files for the requested episodes
|
||||
fpaths = set()
|
||||
for ep_idx in self.episodes:
|
||||
# Need to read metadata to find the file path for this episode
|
||||
# We use self.meta.episodes (the loaded dataset) here
|
||||
ep_meta = self.meta.episodes[ep_idx]
|
||||
# Use the pre-loaded metadata list
|
||||
ep_meta = self.episodes_metadata_list[ep_idx]
|
||||
chunk_idx = ep_meta["data/chunk_index"]
|
||||
file_idx = ep_meta["data/file_index"]
|
||||
fpath_str = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
@@ -886,17 +885,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
# Filter the loaded dataset to *only* include the requested episodes
|
||||
# This is necessary because the v3 files contain multiple episodes.
|
||||
# This is necessary because v3 files can contain multiple episodes.
|
||||
requested_episodes_set = set(self.episodes)
|
||||
hf_dataset = hf_dataset.filter(
|
||||
lambda x: x["episode_index"] in requested_episodes_set,
|
||||
batched=True, # Use batched=True for faster filtering
|
||||
batch_size=1000,
|
||||
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
|
||||
)
|
||||
|
||||
else:
|
||||
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
|
||||
# Load all data files using data_dir, which is the most efficient.
|
||||
# We must use `data_dir` to trigger the v2.1-style efficient cache.
|
||||
data_dir = str(self.root / "data")
|
||||
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user