From 17a690e0481ac4248f68b40a033b30a8b061bf99 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Thu, 25 Jun 2026 17:04:48 +0200
Subject: [PATCH] fix(TIFF vs. pytorch): adding an extra uint16 to float32 conversion for depth maps stored as raw TIFF images

---
 src/lerobot/datasets/io_utils.py | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py
index eaae0d40f..868a114f5 100644
--- a/src/lerobot/datasets/io_utils.py
+++ b/src/lerobot/datasets/io_utils.py
@@ -248,12 +248,28 @@ def load_image_as_numpy(
     return img_array
 
 
+# PIL modes for 16-bit unsigned depth maps.
+UINT16_PIL_MODES = {"I;16", "I;16B", "I;16L"}
+
+
+def pil_to_chw_tensor(img: PILImage.Image) -> torch.Tensor:
+    """Convert a PIL image to a channel-first tensor.
+
+    ``uint16`` depth maps become ``float32 (1, H, W)`` in native units (``ToTensor``
+    would overflow them to ``int16``); all other modes use the standard ``ToTensor`` path.
+    """
+    if img.mode in UINT16_PIL_MODES:
+        return torch.from_numpy(np.array(img, dtype=np.float32))[None, ...]
+    return transforms.ToTensor()(img)
+
+
 def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
     """Convert a batch from a Hugging Face dataset to torch tensors.
 
     This transform function converts items from Hugging Face dataset format (pyarrow)
-    to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
-    to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
+    to torch tensors. RGB images are converted from PIL objects (H, W, C, uint8)
+    to a torch image representation (C, H, W, float32) in the range [0, 1]. Depth
+    maps are returned as float32 (1, H, W) in their native units. Other
     types are converted to torch.tensor.
 
     Args:
@@ -268,8 +284,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
             continue
         first_item = items_dict[key][0]
         if isinstance(first_item, PILImage.Image):
-            to_tensor = transforms.ToTensor()
-            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+            items_dict[key] = [pil_to_chw_tensor(img) for img in items_dict[key]]
         elif first_item is None or isinstance(first_item, dict):
             pass
         else:
@@ -335,7 +350,11 @@ def item_to_torch(item: dict) -> dict:
     """
     skip_keys = {"task", *LANGUAGE_COLUMNS}
     for key, val in item.items():
-        if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
+        if key in skip_keys:
+            continue
+        if isinstance(val, PILImage.Image):
+            item[key] = pil_to_chw_tensor(val)
+        elif isinstance(val, (np.ndarray | list)):
             # Convert numpy arrays and lists to torch tensors
             item[key] = torch.tensor(val)
     return item