mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
small fix for the preprocessor and padded images
This commit is contained in:
@@ -55,19 +55,23 @@ class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T]) that share
|
||||
# the `observation.images.` prefix but are padding flags, not frames. The base crop/resize
|
||||
# matches on the `"image"` substring, so set these aside and restore them untouched rather
|
||||
# than letting it try to resize a mask.
|
||||
pad_keys = {key: value for key, value in observation.items() if "_is_pad" in key}
|
||||
leads: dict[str, tuple] = {}
|
||||
flat_input = dict(observation)
|
||||
for key, img in observation.items():
|
||||
flat_input = {key: value for key, value in observation.items() if key not in pad_keys}
|
||||
for key, img in list(flat_input.items()):
|
||||
if "image" in key and torch.is_tensor(img) and img.ndim > 4:
|
||||
leads[key] = tuple(img.shape[:-3])
|
||||
flat_input[key] = img.reshape(-1, *img.shape[-3:])
|
||||
processed = super().observation(flat_input)
|
||||
if not leads:
|
||||
return processed
|
||||
out = dict(processed)
|
||||
for key, lead in leads.items():
|
||||
im = processed[key]
|
||||
out[key] = im.reshape(*lead, *im.shape[-3:])
|
||||
out.update(pad_keys)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user