Fix(diffusion): enforce no-crop behavior when crop_ratio=1.0 (#3046)
* refactor(diffusion): replace crop_shape with resize_shape and crop_ratio

* fix(diffusion): address review feedback on resize/crop backward compat

* test: regenerate diffusion artifacts for updated default config

* fix: disable crop when resize path uses crop_ratio=1.0

---------

Co-authored-by: starlitxiling <1754165401@qq.com>
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
         normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
             a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
-            within the image size. If None, no cropping is done.
-        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
-            mode).
+        resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
+            backbone. If None, no resizing is done and the original image resolution is used.
+        crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
+            (crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
+            Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
+        crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
+            this is computed automatically. Can also be set directly for legacy configs that use
+            crop-only (without resize). If None and no derivation applies, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center
+            crop in eval mode).
         pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
             `None` means no pretrained weights.
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
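To illustrate the new config surface, here is a minimal sketch; the import path and the use of constructor defaults are assumptions, not taken from this diff:

from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig  # path assumed

# Resize to 96x96, then derive a 90% crop -> crop_shape becomes (86, 86).
cfg = DiffusionConfig(resize_shape=(96, 96), crop_ratio=0.9)

# Resize only: crop_ratio=1.0 now forces crop_shape to None (the fix in this commit).
cfg = DiffusionConfig(resize_shape=(96, 96), crop_ratio=1.0)

# Legacy crop-only config: set crop_shape directly, no resize step.
cfg = DiffusionConfig(crop_shape=(84, 84))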
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
     # Architecture / modeling.
     # Vision backbone.
     vision_backbone: str = "resnet18"
-    crop_shape: tuple[int, int] | None = (84, 84)
+    resize_shape: tuple[int, int] | None = None
+    crop_ratio: float = 1.0
+    crop_shape: tuple[int, int] | None = None
     crop_is_random: bool = True
     pretrained_backbone_weights: str | None = None
     use_group_norm: bool = True
@@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
                 f"Got {self.noise_scheduler_type}."
             )

+        if self.resize_shape is not None and (
+            len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
+        ):
+            raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
+        if not (0 < self.crop_ratio <= 1.0):
+            raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
+
+        if self.resize_shape is not None:
+            if self.crop_ratio < 1.0:
+                self.crop_shape = (
+                    int(self.resize_shape[0] * self.crop_ratio),
+                    int(self.resize_shape[1] * self.crop_ratio),
+                )
+            else:
+                # Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
+                self.crop_shape = None
+        if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
+            raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
+
         # Check that the horizon size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
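A standalone sketch of the derivation above, with a hypothetical helper name that is not part of the diff:

def derive_crop_shape(resize_shape, crop_ratio):
    # Mirrors the __post_init__ logic: derive a crop size or disable cropping.
    if resize_shape is None:
        return None  # nothing to derive; a legacy crop_shape may still apply
    if crop_ratio < 1.0:
        return (int(resize_shape[0] * crop_ratio), int(resize_shape[1] * crop_ratio))
    return None  # crop_ratio == 1.0 -> cropping disabled

assert derive_crop_shape((96, 96), 0.9) == (86, 86)  # int(96 * 0.9) == 86
assert derive_crop_shape((96, 96), 1.0) is None
assert derive_crop_shape(None, 0.9) is None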
@@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
         if len(self.image_features) == 0 and self.env_state_feature is None:
             raise ValueError("You must provide at least one image or the environment state among the inputs.")

-        if self.crop_shape is not None:
+        if self.resize_shape is None and self.crop_shape is not None:
             for key, image_ft in self.image_features.items():
                 if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                     raise ValueError(
-                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
-                        f"for `crop_shape` and {image_ft.shape} for "
-                        f"`{key}`."
+                        f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
+                        f"for `crop_shape` and {image_ft.shape} for `{key}`."
                     )

         # Check that all input images have the same shape.
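The fit check now runs only for the legacy crop-only path: once resize_shape is set, a crop derived with crop_ratio <= 1.0 fits the resized image by construction. A small sketch with hypothetical sizes:

raw_h_w = (64, 64)       # hypothetical dataset image size
legacy_crop = (84, 84)   # old default crop, larger than the raw image
assert legacy_crop[0] > raw_h_w[0]  # crop-only config -> the validation above raises

resized_h_w = (96, 96)
derived_crop = (int(96 * 0.9), int(96 * 0.9))  # (86, 86): never exceeds resized size
assert derived_crop[0] <= resized_h_w[0]       # resize path needs no fit check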
@@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if config.crop_shape is not None:
+        if config.resize_shape is not None:
+            self.resize = torchvision.transforms.Resize(config.resize_shape)
+        else:
+            self.resize = None
+
+        crop_shape = config.crop_shape
+        if crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
-            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
             if config.crop_is_random:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
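Note the interaction with the config fix: the crop transforms are only built when crop_shape survives __post_init__. A sketch with assumed config values:

import torchvision

config_resize_shape = (96, 96)  # hypothetical config values
config_crop_shape = None        # crop_ratio == 1.0 -> set to None by this commit

resize = torchvision.transforms.Resize(config_resize_shape) if config_resize_shape else None
do_crop = config_crop_shape is not None
assert resize is not None and not do_crop  # resize-only encoder: no crop ops built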
@@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):

         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
-        # The dummy input should take the number of image channels from `config.image_features` and it should
-        # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
-        # height and width from `config.image_features`.
+        # The dummy shape mirrors the runtime preprocessing order: resize -> crop.

         # Note: we have a check in the config class to make sure all images have the same shape.
         images_shape = next(iter(config.image_features.values())).shape
-        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        if config.crop_shape is not None:
+            dummy_shape_h_w = config.crop_shape
+        elif config.resize_shape is not None:
+            dummy_shape_h_w = config.resize_shape
+        else:
+            dummy_shape_h_w = images_shape[1:]
         dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
         feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

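The dry-run dummy shape follows a simple priority; a sketch of just that selection (hypothetical helper and sizes):

def dummy_h_w(crop_shape, resize_shape, image_h_w):
    # Priority mirrors the pipeline's final output: crop > resize > raw image.
    if crop_shape is not None:
        return crop_shape
    if resize_shape is not None:
        return resize_shape
    return image_h_w

assert dummy_h_w((86, 86), (96, 96), (128, 128)) == (86, 86)  # resize + crop
assert dummy_h_w(None, (96, 96), (128, 128)) == (96, 96)      # resize only
assert dummy_h_w(None, None, (128, 128)) == (128, 128)        # no preprocessing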
@@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
         Returns:
             (B, D) image feature.
         """
-        # Preprocess: maybe crop (if it was set up in the __init__).
+        # Preprocess: resize if configured, then crop if configured.
+
+        if self.resize is not None:
+            x = self.resize(x)
         if self.do_crop:
             if self.training:  # noqa: SIM108
                 x = self.maybe_random_crop(x)
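End to end, forward() now resizes first and then crops, with the crop random only at training time. A runnable sketch using real torchvision transforms and assumed sizes:

import torch
import torchvision

resize = torchvision.transforms.Resize((96, 96))
random_crop = torchvision.transforms.RandomCrop((86, 86))
center_crop = torchvision.transforms.CenterCrop((86, 86))

x = torch.rand(1, 3, 128, 128)  # hypothetical (B, C, H, W) batch
x = resize(x)                   # -> (1, 3, 96, 96)
training = True
x = random_crop(x) if training else center_crop(x)  # both yield (1, 3, 86, 86)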