mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-27 05:07:15 +00:00
fix(TIFF vs. pytorch): adding an extra uint16 to float32 conversion for depth maps stored as raw TIFF images
This commit is contained in:
@@ -248,12 +248,28 @@ def load_image_as_numpy(
|
||||
return img_array
|
||||
|
||||
|
||||
# PIL image modes treated as 16-bit unsigned depth maps: ``pil_to_chw_tensor``
# converts images in these modes to float32 itself instead of using ``ToTensor``.
|
||||
UINT16_PIL_MODES = {"I;16", "I;16B", "I;16L"}
|
||||
|
||||
|
||||
def pil_to_chw_tensor(img: PILImage.Image) -> torch.Tensor:
    """Turn a PIL image into a channel-first ``torch.Tensor``.

    Images whose mode is listed in ``UINT16_PIL_MODES`` (``uint16`` depth maps)
    are read into a ``float32`` array with pixel values kept in their native
    units, and a leading axis is prepended so the result is ``(1, H, W)``. They
    bypass ``ToTensor`` because that path would overflow them to ``int16``.
    Every other mode goes through the standard ``ToTensor`` transform.

    Args:
        img: The PIL image to convert.

    Returns:
        The converted tensor.
    """
    is_uint16_depth = img.mode in UINT16_PIL_MODES
    if not is_uint16_depth:
        return transforms.ToTensor()(img)

    # Cast to float32 first, then prepend the channel axis.
    depth = np.array(img, dtype=np.float32)
    return torch.from_numpy(depth)[None, ...]
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
to torch tensors. RGB images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Depth
|
||||
maps are returned as float32 (1, H, W) in their native units. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
@@ -268,8 +284,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
items_dict[key] = [pil_to_chw_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
@@ -335,7 +350,11 @@ def item_to_torch(item: dict) -> dict:
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
if key in skip_keys:
|
||||
continue
|
||||
if isinstance(val, PILImage.Image):
|
||||
item[key] = pil_to_chw_tensor(val)
|
||||
elif isinstance(val, (np.ndarray | list)):
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
Reference in New Issue
Block a user