try fix 6

This commit is contained in:
Steven Palma
2025-11-05 21:42:31 +01:00
parent bbcffc4999
commit e195f8d287
+21 -21
View File
@@ -81,23 +81,30 @@ CODEBASE_VERSION = "v3.0"
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. ...
[This is the v2.1 item-level transform]
"""
Converts a single item (row) from a Hugging Face dataset to torch tensors.
This item-level transform allows `datasets` to build an efficient, pre-processed cache.
"""
for key in items_dict:
if items_dict[key] is None:
item = items_dict[key]
if item is None:
continue
if isinstance(items_dict[key], PILImage.Image):
# PIL image (h w c) (uint8)
to_tensor = transforms.ToTensor()
items_dict[key] = to_tensor(items_dict[key])
elif isinstance(items_dict[key], str):
# keep as is
if isinstance(item, PILImage.Image):
# Correctly transform PIL images to tensors
items_dict[key] = transforms.ToTensor()(item)
elif isinstance(item, (str, bytes)):
# Let strings (like 'task') pass through untouched
pass
else:
# This handles tensors, ints, floats, etc.
items_dict[key] = torch.tensor(items_dict[key])
# Convert all other numeric types (int, float, list, np.ndarray) to tensors
try:
items_dict[key] = torch.tensor(item)
except Exception as e:
# Catch errors like the one you saw
print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
raise e
return items_dict
@@ -857,15 +864,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
features = get_hf_features_from_features(self.features)
# 1. Check if specific episodes are requested by the user.
# This is the "data_files" path, which may be slow, but is
# necessary for visualization or evaluation on a subset.
if self.episodes is not None:
# Get the unique set of parquet files for the requested episodes
# Path for episode-specific loading (e.g., visualization)
fpaths = set()
for ep_idx in self.episodes:
# Need to read metadata to find the file path for this episode
# Use the pre-loaded metadata list
ep_meta = self.episodes_metadata_list[ep_idx]
chunk_idx = ep_meta["data/chunk_index"]
file_idx = ep_meta["data/file_index"]
@@ -878,8 +880,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"parquet", data_files=data_files, features=features, split="train"
)
# Filter the loaded dataset to *only* include the requested episodes
# This is necessary because v3 files can contain multiple episodes.
requested_episodes_set = set(self.episodes)
hf_dataset = hf_dataset.filter(
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
@@ -887,7 +887,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
# We must use `data_dir` to trigger the v2.1-style efficient cache.
# Use `data_dir` to trigger the v2.1-style efficient cache.
data_dir = str(self.root / "data")
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")