def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
    """Convert a batch from a Hugging Face dataset to torch tensors.

    ``items_dict`` maps each column name to the list of values that column holds
    for the batch. Every value is handled on its own, so a column may mix ``None``
    with real values:

    - ``None`` values are kept as ``None``.
    - ``str`` / ``bytes`` values (e.g. 'task') pass through untouched.
    - PIL images are converted with ``transforms.ToTensor()``.
    - Anything else (int, float, list, np.ndarray) goes through ``torch.tensor``.

    Args:
        items_dict: Batch in column format (column name -> list of values).
            It is updated in place.

    Returns:
        The same ``items_dict`` object; every non-empty column is replaced by a
        new list holding the converted values.

    Raises:
        Exception: Whatever ``torch.tensor`` raises for a value it cannot convert.
            The column name, offending value and its type are printed first, then
            the original exception is re-raised unchanged.
    """
    # Create a single ToTensor transform instance and reuse it for the whole batch.
    to_tensor = transforms.ToTensor()

    # Only values are reassigned below (no keys added or removed), so iterating
    # over .items() while updating the dict is safe.
    for key, items_list in items_dict.items():
        # Nothing to convert for an empty column.
        if not items_list:
            continue

        converted = []
        for item in items_list:
            # Dispatch on each item rather than on the first one only: a column
            # can start with None (or contain None later) and still hold values
            # that need converting.
            if item is None or isinstance(item, (str, bytes)):
                converted.append(item)
            elif isinstance(item, PILImage.Image):
                # This is the (slow) CPU-bound part of the transform.
                converted.append(to_tensor(item))
            else:
                try:
                    converted.append(torch.tensor(item))
                except Exception:
                    # Report the value that actually failed, then re-raise with
                    # the original traceback intact.
                    print(
                        f"Error converting batch['{key}'] to tensor. Value: {item}, Type: {type(item)}"
                    )
                    raise

        items_dict[key] = converted

    return items_dict