From baf9b5036586f3667c6f5310d30396b4b233a801 Mon Sep 17 00:00:00 2001
From: Khalil Meftah
Date: Fri, 27 Feb 2026 17:44:53 +0100
Subject: [PATCH] Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0
 (#3046)

* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
---
 .../diffusion/configuration_diffusion.py      | 44 +++++++++++++++----
 .../policies/diffusion/modeling_diffusion.py  | 28 ++++++++----
 .../pusht_diffusion_/actions.safetensors      |  2 +-
 .../pusht_diffusion_/grad_stats.safetensors   |  2 +-
 .../pusht_diffusion_/output_dict.safetensors  |  2 +-
 .../pusht_diffusion_/param_stats.safetensors  |  2 +-
 6 files changed, 59 insertions(+), 21 deletions(-)

diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py
index 3d30e0941..91b3df214 100644
--- a/src/lerobot/policies/diffusion/configuration_diffusion.py
+++ b/src/lerobot/policies/diffusion/configuration_diffusion.py
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
         normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE",
             "VISUAL") to a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
-            within the image size. If None, no cropping is done.
-        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
-            mode).
+        resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
+            backbone. If None, no resizing is done and the original image resolution is used.
+        crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
+            (crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
+            Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
+        crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
+            this is computed automatically. Can also be set directly for legacy configs that use
+            crop-only (without resize). If None and no derivation applies, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center
+            crop in eval mode).
         pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
             `None` means no pretrained weights.
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
     # Architecture / modeling.
     # Vision backbone.
     vision_backbone: str = "resnet18"
-    crop_shape: tuple[int, int] | None = (84, 84)
+    resize_shape: tuple[int, int] | None = None
+    crop_ratio: float = 1.0
+    crop_shape: tuple[int, int] | None = None
     crop_is_random: bool = True
     pretrained_backbone_weights: str | None = None
     use_group_norm: bool = True
@@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
                 f"Got {self.noise_scheduler_type}."
             )
 
+        if self.resize_shape is not None and (
+            len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
+        ):
+            raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
+        if not (0 < self.crop_ratio <= 1.0):
+            raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
+
+        if self.resize_shape is not None:
+            if self.crop_ratio < 1.0:
+                self.crop_shape = (
+                    int(self.resize_shape[0] * self.crop_ratio),
+                    int(self.resize_shape[1] * self.crop_ratio),
+                )
+            else:
+                # Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
+                self.crop_shape = None
+        if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
+            raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
+
         # Check that the horizon size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
@@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
         if len(self.image_features) == 0 and self.env_state_feature is None:
             raise ValueError("You must provide at least one image or the environment state among the inputs.")
 
-        if self.crop_shape is not None:
+        if self.resize_shape is None and self.crop_shape is not None:
             for key, image_ft in self.image_features.items():
                 if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                     raise ValueError(
-                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
-                        f"for `crop_shape` and {image_ft.shape} for "
-                        f"`{key}`."
+                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
+                        f"for `crop_shape` and {image_ft.shape} for `{key}`."
                     )
 
         # Check that all input images have the same shape.
diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py
index 314ca369c..aa8d5dd14 100644
--- a/src/lerobot/policies/diffusion/modeling_diffusion.py
+++ b/src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if config.crop_shape is not None:
+        if config.resize_shape is not None:
+            self.resize = torchvision.transforms.Resize(config.resize_shape)
+        else:
+            self.resize = None
+
+        crop_shape = config.crop_shape
+        if crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
-            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
             if config.crop_is_random:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
@@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
 
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.image_features` and it should
-        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
-        # height and width from `config.image_features`.
+        # The dummy shape mirrors the runtime preprocessing order: resize -> crop.
         # Note: we have a check in the config class to make sure all images have the same shape.
         images_shape = next(iter(config.image_features.values())).shape
-        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        if config.crop_shape is not None:
+            dummy_shape_h_w = config.crop_shape
+        elif config.resize_shape is not None:
+            dummy_shape_h_w = config.resize_shape
+        else:
+            dummy_shape_h_w = images_shape[1:]
         dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
         feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
 
         self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
@@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
         Returns:
             (B, D) image feature.
         """
-        # Preprocess: maybe crop (if it was set up in the __init__).
+        # Preprocess: resize if configured, then crop if configured.
+        if self.resize is not None:
+            x = self.resize(x)
+
         if self.do_crop:
             if self.training:  # noqa: SIM108
                 x = self.maybe_random_crop(x)
diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors
index ef581727d..70b1411ab 100644
--- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors
+++ b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
+oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
 size 992
diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors
index e00ed3238..bea7d4f19 100644
--- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors
+++ b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
+oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
 size 47424
diff --git a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors
index f29303992..20cc4f547 100644
--- a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors
+++ b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
+oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
 size 68
diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors
index 614cc754e..365a453dd 100644
--- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors
+++ b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
+oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
 size 47424
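
-- 
A minimal usage sketch of the precedence rules this patch introduces, assuming
the resize/crop derivation runs at config-validation time (as in the
`__post_init__`-style block in the first hunk) and that `DiffusionConfig` can
be instantiated with its defaults:

    from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig

    # Resize then partial crop: crop_shape is derived as int(resize * ratio).
    cfg = DiffusionConfig(resize_shape=(96, 96), crop_ratio=0.875)
    assert cfg.crop_shape == (84, 84)  # int(96 * 0.875) == 84

    # Resize with crop_ratio=1.0: cropping is explicitly disabled (this PR's fix).
    cfg = DiffusionConfig(resize_shape=(96, 96), crop_ratio=1.0)
    assert cfg.crop_shape is None

    # Legacy crop-only config: honored unchanged when resize_shape is None.
    cfg = DiffusionConfig(crop_shape=(84, 84))
    assert cfg.crop_shape == (84, 84)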