try fix 6

This commit is contained in:
Steven Palma
2025-11-05 21:42:31 +01:00
parent bbcffc4999
commit e195f8d287
+21 -21
View File
@@ -81,23 +81,30 @@ CODEBASE_VERSION = "v3.0"
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]: def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
"""Get a transform function that convert items from Hugging Face dataset (pyarrow) """
to torch tensors. ... Converts a single item (row) from a Hugging Face dataset to torch tensors.
[This is the v2.1 item-level transform] This item-level transform allows `datasets` to build an efficient, pre-processed cache.
""" """
for key in items_dict: for key in items_dict:
if items_dict[key] is None: item = items_dict[key]
if item is None:
continue continue
if isinstance(items_dict[key], PILImage.Image):
# PIL image (h w c) (uint8) if isinstance(item, PILImage.Image):
to_tensor = transforms.ToTensor() # Correctly transform PIL images to tensors
items_dict[key] = to_tensor(items_dict[key]) items_dict[key] = transforms.ToTensor()(item)
elif isinstance(items_dict[key], str): elif isinstance(item, (str, bytes)):
# keep as is # Let strings (like 'task') pass through untouched
pass pass
else: else:
# This handles tensors, ints, floats, etc. # Convert all other numeric types (int, float, list, np.ndarray) to tensors
items_dict[key] = torch.tensor(items_dict[key]) try:
items_dict[key] = torch.tensor(item)
except Exception as e:
# Catch errors like the one you saw
print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
raise e
return items_dict return items_dict
@@ -857,15 +864,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
features = get_hf_features_from_features(self.features) features = get_hf_features_from_features(self.features)
# 1. Check if specific episodes are requested by the user.
# This is the "data_files" path, which may be slow, but is
# necessary for visualization or evaluation on a subset.
if self.episodes is not None: if self.episodes is not None:
# Get the unique set of parquet files for the requested episodes # Path for episode-specific loading (e.g., visualization)
fpaths = set() fpaths = set()
for ep_idx in self.episodes: for ep_idx in self.episodes:
# Need to read metadata to find the file path for this episode
# Use the pre-loaded metadata list
ep_meta = self.episodes_metadata_list[ep_idx] ep_meta = self.episodes_metadata_list[ep_idx]
chunk_idx = ep_meta["data/chunk_index"] chunk_idx = ep_meta["data/chunk_index"]
file_idx = ep_meta["data/file_index"] file_idx = ep_meta["data/file_index"]
@@ -878,8 +880,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"parquet", data_files=data_files, features=features, split="train" "parquet", data_files=data_files, features=features, split="train"
) )
# Filter the loaded dataset to *only* include the requested episodes
# This is necessary because v3 files can contain multiple episodes.
requested_episodes_set = set(self.episodes) requested_episodes_set = set(self.episodes)
hf_dataset = hf_dataset.filter( hf_dataset = hf_dataset.filter(
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000 lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
@@ -887,7 +887,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
else: else:
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None) # THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
# We must use `data_dir` to trigger the v2.1-style efficient cache. # Use `data_dir` to trigger the v2.1-style efficient cache.
data_dir = str(self.root / "data") data_dir = str(self.root / "data")
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train") hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")