try fix 5

This commit is contained in:
Steven Palma
2025-11-05 21:34:10 +01:00
parent 20333abc72
commit bbcffc4999
+13 -19
View File
@@ -80,30 +80,24 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
"""Convert the items of a single Hugging Face dataset (pyarrow) example
to torch tensors. ...
[This is the v2.1 item-level transform]
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
if items_dict[key] is None:
continue
if isinstance(items_dict[key], PILImage.Image):
# PIL image (h w c) (uint8)
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
items_dict[key] = to_tensor(items_dict[key])
elif isinstance(items_dict[key], str):
# keep as is
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
# This handles tensors, ints, floats, etc.
items_dict[key] = torch.tensor(items_dict[key])
return items_dict