try fix 9

This commit is contained in:
Steven Palma
2025-11-05 21:52:15 +01:00
parent 76f25f6afd
commit bcc13f1d90
+32 -15
View File
@@ -80,31 +80,48 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
# Version tag of this codebase; presumably checked against stored dataset metadata elsewhere — TODO confirm.
CODEBASE_VERSION = "v3.0"
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""
Converts a single item (row) from a Hugging Face dataset to torch tensors.
This item-level transform allows `datasets` to build an efficient, pre-processed cache.
Converts a batch from a Hugging Face dataset to torch tensors.
"""
for key in items_dict:
item = items_dict[key]
if item is None:
# Create a single ToTensor transform instance to reuse
to_tensor = transforms.ToTensor()
for key in items_dict:
items_list = items_dict[key]
# Check if the list is non-empty
if not items_list:
continue
if isinstance(item, PILImage.Image):
# Correctly transform PIL images to tensors
items_dict[key] = transforms.ToTensor()(item)
elif isinstance(item, (str, bytes)):
# Let strings (like 'task') pass through untouched
first_item = items_list[0]
if isinstance(first_item, PILImage.Image):
# This is the (slow) CPU-bound part.
# We convert every image in the batch list to a tensor.
items_dict[key] = [to_tensor(img) for img in items_list]
elif isinstance(first_item, (str, bytes)):
# List of strings (e.g., 'task'), do nothing
pass
elif first_item is None:
# List of Nones, do nothing
pass
else:
# Convert all other numeric types (int, float, list, np.ndarray) to tensors
# List of other things (int, float, list, np.ndarray)
try:
items_dict[key] = torch.tensor(item)
# Convert each item in the list to a tensor
items_dict[key] = [torch.tensor(item) for item in items_list]
except Exception as e:
# Catch errors like the one you saw
print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
# This catch is what was missing from the original v3.0 code
print(
f"Error converting batch['{key}'] to tensor. First item: {first_item}, Type: {type(first_item)}"
)
raise e
return items_dict