From eaf0218bc8ab12e2119b8b837c9b66eb9a7e3823 Mon Sep 17 00:00:00 2001 From: Ville Kuosmanen Date: Thu, 7 May 2026 11:10:38 +0100 Subject: [PATCH] feat(policy): use pretrained vision encoder weights by default for diffusion and vqbet (#3202) * feat: add pretrained vision encoder weights for diffusion and vqbet * fix test by re-generating artifacts --------- Co-authored-by: Steven Palma --- .../policies/diffusion/configuration_diffusion.py | 10 +++++----- src/lerobot/policies/vqbet/configuration_vqbet.py | 4 ++-- .../policies/pusht_diffusion_/actions.safetensors | 2 +- .../policies/pusht_diffusion_/grad_stats.safetensors | 4 ++-- .../policies/pusht_diffusion_/output_dict.safetensors | 2 +- .../policies/pusht_diffusion_/param_stats.safetensors | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8e3d4bf19..ed04ab54d 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig): # Inputs / output structure. n_obs_steps: int = 2 - horizon: int = 16 - n_action_steps: int = 8 + horizon: int = 64 + n_action_steps: int = 32 normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { @@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig): crop_ratio: float = 1.0 crop_shape: tuple[int, int] | None = None crop_is_random: bool = True - pretrained_backbone_weights: str | None = None - use_group_norm: bool = True + pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" + use_group_norm: bool = False spatial_softmax_num_keypoints: int = 32 - use_separate_rgb_encoder_per_camera: bool = False + use_separate_rgb_encoder_per_camera: bool = True # Unet. down_dims: tuple[int, ...] = (512, 1024, 2048) kernel_size: int = 5 diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py index d02745321..e5c1754e8 100644 --- a/src/lerobot/policies/vqbet/configuration_vqbet.py +++ b/src/lerobot/policies/vqbet/configuration_vqbet.py @@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig): vision_backbone: str = "resnet18" crop_shape: tuple[int, int] | None = (84, 84) crop_is_random: bool = True - pretrained_backbone_weights: str | None = None - use_group_norm: bool = True + pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" + use_group_norm: bool = False spatial_softmax_num_keypoints: int = 32 # VQ-VAE n_vqvae_training_steps: int = 20000 diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors index 70b1411ab..65b8d5ca5 100644 --- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b +oid sha256:51effd76b73e972f10d31f5084ab906386134b600c87b2668767d30232a902bd size 992 diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors index bea7d4f19..4a5593a46 100644 --- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013 -size 47424 +oid sha256:d4d7a16ca67f9adefac0e0620a7b2e9c822f2db42faaaced7a89fbad60e5ead4 +size 47680 diff --git a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors index 20cc4f547..f47997b8f 100644 --- a/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819 +oid sha256:796c439ee8a64bf9901ff8325e7419bda8bd316360ee95e6304e8e1ae0f4c36c size 68 diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors index 365a453dd..104a05f96 100644 --- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6 -size 47424 +oid sha256:ad33a8b47c39c2e1374567ff9da43cdb95e2dbe904c1b02a35051346d3043095 +size 47680