fix(TIFF vs. pytorch): adding an extra uint16 to float32 conversion for depth maps stored as raw TIFF images

This commit is contained in:
CarolinePascal
2026-06-25 17:04:48 +02:00
parent 0e39bae335
commit 17a690e048
+24 -5
View File
@@ -248,12 +248,28 @@ def load_image_as_numpy(
return img_array
# PIL modes for 16-bit unsigned depth maps.
UINT16_PIL_MODES = {"I;16", "I;16B", "I;16L"}
def pil_to_chw_tensor(img: PILImage.Image) -> torch.Tensor:
"""Convert a PIL image to a channel-first tensor.
``uint16`` depth maps become ``float32 (1, H, W)`` in native units (``ToTensor``
would overflow them to ``int16``); all other modes use the standard ``ToTensor`` path.
"""
if img.mode in UINT16_PIL_MODES:
return torch.from_numpy(np.array(img, dtype=np.float32))[None, ...]
return transforms.ToTensor()(img)
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
to torch tensors. RGB images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Depth
maps are returned as float32 (1, H, W) in their native units. Other
types are converted to torch.tensor.
Args:
@@ -268,8 +284,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
items_dict[key] = [pil_to_chw_tensor(img) for img in items_dict[key]]
elif first_item is None or isinstance(first_item, dict):
pass
else:
@@ -335,7 +350,11 @@ def item_to_torch(item: dict) -> dict:
"""
skip_keys = {"task", *LANGUAGE_COLUMNS}
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
if key in skip_keys:
continue
if isinstance(val, PILImage.Image):
item[key] = pil_to_chw_tensor(val)
elif isinstance(val, (np.ndarray | list)):
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item