Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046)

* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
This commit is contained in:
Khalil Meftah
2026-02-27 17:44:53 +01:00
committed by GitHub
parent a0fdbf037a
commit baf9b50365
6 changed files with 59 additions and 21 deletions
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
vision_backbone: Name of the torchvision resnet backbone to use for encoding images. vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
within the image size. If None, no cropping is done. backbone. If None, no resizing is done and the original image resolution is used.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
mode). (crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
this is computed automatically. Can also be set directly for legacy configs that use
crop-only (without resize). If None and no derivation applies, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center
crop in eval mode).
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
`None` means no pretrained weights. `None` means no pretrained weights.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone. use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
# Architecture / modeling. # Architecture / modeling.
# Vision backbone. # Vision backbone.
vision_backbone: str = "resnet18" vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84) resize_shape: tuple[int, int] | None = None
crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True crop_is_random: bool = True
pretrained_backbone_weights: str | None = None pretrained_backbone_weights: str | None = None
use_group_norm: bool = True use_group_norm: bool = True
@@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
f"Got {self.noise_scheduler_type}." f"Got {self.noise_scheduler_type}."
) )
if self.resize_shape is not None and (
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
):
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
if not (0 < self.crop_ratio <= 1.0):
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
if self.resize_shape is not None:
if self.crop_ratio < 1.0:
self.crop_shape = (
int(self.resize_shape[0] * self.crop_ratio),
int(self.resize_shape[1] * self.crop_ratio),
)
else:
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
self.crop_shape = None
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
# Check that the horizon size and U-Net downsampling is compatible. # Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage. # U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims) downsampling_factor = 2 ** len(self.down_dims)
@@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
if len(self.image_features) == 0 and self.env_state_feature is None: if len(self.image_features) == 0 and self.env_state_feature is None:
raise ValueError("You must provide at least one image or the environment state among the inputs.") raise ValueError("You must provide at least one image or the environment state among the inputs.")
if self.crop_shape is not None: if self.resize_shape is None and self.crop_shape is not None:
for key, image_ft in self.image_features.items(): for key, image_ft in self.image_features.items():
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]: if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
raise ValueError( raise ValueError(
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} " f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
f"for `crop_shape` and {image_ft.shape} for " f"for `crop_shape` and {image_ft.shape} for `{key}`."
f"`{key}`."
) )
# Check that all input images have the same shape. # Check that all input images have the same shape.
@@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
def __init__(self, config: DiffusionConfig): def __init__(self, config: DiffusionConfig):
super().__init__() super().__init__()
# Set up optional preprocessing. # Set up optional preprocessing.
if config.crop_shape is not None: if config.resize_shape is not None:
self.resize = torchvision.transforms.Resize(config.resize_shape)
else:
self.resize = None
crop_shape = config.crop_shape
if crop_shape is not None:
self.do_crop = True self.do_crop = True
# Always use center crop for eval # Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
if config.crop_is_random: if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
else: else:
self.maybe_random_crop = self.center_crop self.maybe_random_crop = self.center_crop
else: else:
@@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers. # Set up pooling and final layers.
# Use a dry run to get the feature map shape. # Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.image_features` and it should # The dummy shape mirrors the runtime preprocessing order: resize -> crop.
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape. # Note: we have a check in the config class to make sure all images have the same shape.
images_shape = next(iter(config.image_features.values())).shape images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:] if config.crop_shape is not None:
dummy_shape_h_w = config.crop_shape
elif config.resize_shape is not None:
dummy_shape_h_w = config.resize_shape
else:
dummy_shape_h_w = images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w) dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:] feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
@@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
Returns: Returns:
(B, D) image feature. (B, D) image feature.
""" """
# Preprocess: maybe crop (if it was set up in the __init__). # Preprocess: resize if configured, then crop if configured.
if self.resize is not None:
x = self.resize(x)
if self.do_crop: if self.do_crop:
if self.training: # noqa: SIM108 if self.training: # noqa: SIM108
x = self.maybe_random_crop(x) x = self.maybe_random_crop(x)
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8 oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
size 992 size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002 oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
size 47424 size 47424
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5 oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
size 68 size 68
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
size 47424 size 47424