From bbcffc4999ef25ef18f7ba723e32c47918d2a4ad Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 5 Nov 2025 21:34:10 +0100 Subject: [PATCH] try fix 5 --- src/lerobot/datasets/lerobot_dataset.py | 32 ++++++++++--------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 9484a4342..99e210220 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -80,30 +80,24 @@ from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" -def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: - """Convert a batch from a Hugging Face dataset to torch tensors. - - This transform function converts items from Hugging Face dataset format (pyarrow) - to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) - to a torch image representation (C, H, W, float32) in the range [0, 1]. Other - types are converted to torch.tensor. - - Args: - items_dict (dict): A dictionary representing a batch of data from a - Hugging Face dataset. - - Returns: - dict: The batch with items converted to torch tensors. +def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]: + """Convert items from a Hugging Face dataset (pyarrow) + to torch tensors. ... 
+ [This is the v2.1 item-level transform] """ for key in items_dict: - first_item = items_dict[key][0] - if isinstance(first_item, PILImage.Image): + if items_dict[key] is None: + continue + if isinstance(items_dict[key], PILImage.Image): + # PIL image (h w c) (uint8) to_tensor = transforms.ToTensor() - items_dict[key] = [to_tensor(img) for img in items_dict[key]] - elif first_item is None: + items_dict[key] = to_tensor(items_dict[key]) + elif isinstance(items_dict[key], str): + # keep as is pass else: - items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] + # This handles tensors, ints, floats, etc. + items_dict[key] = torch.tensor(items_dict[key]) return items_dict