small fix for the preprocessor and padded images

This commit is contained in:
Maxime Ellerbach
2026-06-16 11:27:51 +00:00
parent 1e762d5240
commit ecf342d481
@@ -55,19 +55,23 @@ class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
"""
def observation(self, observation: dict) -> dict:
    """Crop/resize image tensors that carry extra leading dimensions.

    Image tensors with more than four dimensions are flattened to a single leading
    dimension before being handed to the base implementation, then reshaped back to
    their original leading dimensions afterwards. Entries whose key contains
    ``"_is_pad"`` bypass the base implementation and are returned unchanged.

    Args:
        observation: Mapping of observation keys to values. It is not mutated.

    Returns:
        The base implementation's output with the original leading dimensions
        restored on flattened images and the ``_is_pad`` entries merged back in.
    """
    # Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T]) that share
    # the `observation.images.` prefix but are padding flags, not frames. The base crop/resize
    # matches on the `"image"` substring, so set these aside and restore them untouched rather
    # than letting it try to resize a mask.
    pad_keys = {key: value for key, value in observation.items() if "_is_pad" in key}
    leads: dict[str, tuple] = {}
    flat_input = {key: value for key, value in observation.items() if key not in pad_keys}
    for key, img in list(flat_input.items()):
        if "image" in key and torch.is_tensor(img) and img.ndim > 4:
            # Remember every leading dim (everything before the last three) so the
            # processed tensor can be un-flattened afterwards.
            leads[key] = tuple(img.shape[:-3])
            flat_input[key] = img.reshape(-1, *img.shape[-3:])
    processed = super().observation(flat_input)
    # Fast path only when there is nothing to restore. Returning early whenever `leads`
    # was empty used to drop the pad masks that were set aside above.
    if not leads and not pad_keys:
        return processed
    out = dict(processed)
    for key, lead in leads.items():
        im = processed[key]
        # The last three dims may have changed size in the base step; re-read them from
        # the processed tensor.
        out[key] = im.reshape(*lead, *im.shape[-3:])
    out.update(pad_keys)
    return out