feat(policy): use pretrained vision encoder weights by default for diffusion and vqbet (#3202)

* feat: add pretrained vision encoder weights for diffusion and vqbet

* fix test by re-generating artifacts

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Ville Kuosmanen
2026-05-07 11:10:38 +01:00
committed by GitHub
parent a0e52d52fe
commit eaf0218bc8
6 changed files with 13 additions and 13 deletions
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
# Inputs / output structure. # Inputs / output structure.
n_obs_steps: int = 2 n_obs_steps: int = 2
horizon: int = 16 horizon: int = 64
n_action_steps: int = 8 n_action_steps: int = 32
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
crop_ratio: float = 1.0 crop_ratio: float = 1.0
crop_shape: tuple[int, int] | None = None crop_shape: tuple[int, int] | None = None
crop_is_random: bool = True crop_is_random: bool = True
pretrained_backbone_weights: str | None = None pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
use_group_norm: bool = True use_group_norm: bool = False
spatial_softmax_num_keypoints: int = 32 spatial_softmax_num_keypoints: int = 32
use_separate_rgb_encoder_per_camera: bool = False use_separate_rgb_encoder_per_camera: bool = True
# Unet. # Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048) down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5 kernel_size: int = 5
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
vision_backbone: str = "resnet18" vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84) crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True crop_is_random: bool = True
pretrained_backbone_weights: str | None = None pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
use_group_norm: bool = True use_group_norm: bool = False
spatial_softmax_num_keypoints: int = 32 spatial_softmax_num_keypoints: int = 32
# VQ-VAE # VQ-VAE
n_vqvae_training_steps: int = 20000 n_vqvae_training_steps: int = 20000
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b oid sha256:51effd76b73e972f10d31f5084ab906386134b600c87b2668767d30232a902bd
size 992 size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013 oid sha256:d4d7a16ca67f9adefac0e0620a7b2e9c822f2db42faaaced7a89fbad60e5ead4
size 47424 size 47680
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819 oid sha256:796c439ee8a64bf9901ff8325e7419bda8bd316360ee95e6304e8e1ae0f4c36c
size 68 size 68
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6 oid sha256:ad33a8b47c39c2e1374567ff9da43cdb95e2dbe904c1b02a35051346d3043095
size 47424 size 47680