mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
try fix 9
This commit is contained in:
@@ -80,31 +80,48 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""
|
||||
Converts a single item (row) from a Hugging Face dataset to torch tensors.
|
||||
This item-level transform allows `datasets` to build an efficient, pre-processed cache.
|
||||
Converts a batch from a Hugging Face dataset to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
item = items_dict[key]
|
||||
|
||||
if item is None:
|
||||
# Create a single ToTensor transform instance to reuse
|
||||
to_tensor = transforms.ToTensor()
|
||||
|
||||
for key in items_dict:
|
||||
items_list = items_dict[key]
|
||||
|
||||
# Check if the list is non-empty
|
||||
if not items_list:
|
||||
continue
|
||||
|
||||
if isinstance(item, PILImage.Image):
|
||||
# Correctly transform PIL images to tensors
|
||||
items_dict[key] = transforms.ToTensor()(item)
|
||||
elif isinstance(item, (str, bytes)):
|
||||
# Let strings (like 'task') pass through untouched
|
||||
first_item = items_list[0]
|
||||
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
# This is the (slow) CPU-bound part.
|
||||
# We convert every image in the batch list to a tensor.
|
||||
items_dict[key] = [to_tensor(img) for img in items_list]
|
||||
|
||||
elif isinstance(first_item, (str, bytes)):
|
||||
# List of strings (e.g., 'task'), do nothing
|
||||
pass
|
||||
|
||||
elif first_item is None:
|
||||
# List of Nones, do nothing
|
||||
pass
|
||||
|
||||
else:
|
||||
# Convert all other numeric types (int, float, list, np.ndarray) to tensors
|
||||
# List of other things (int, float, list, np.ndarray)
|
||||
try:
|
||||
items_dict[key] = torch.tensor(item)
|
||||
# Convert each item in the list to a tensor
|
||||
items_dict[key] = [torch.tensor(item) for item in items_list]
|
||||
except Exception as e:
|
||||
# Catch errors like the one you saw
|
||||
print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
|
||||
# This catch is what was missing from the original v3.0 code
|
||||
print(
|
||||
f"Error converting batch['{key}'] to tensor. First item: {first_item}, Type: {type(first_item)}"
|
||||
)
|
||||
raise e
|
||||
|
||||
return items_dict
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user