mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
feat(policy): use pretrained vision encoder weights by default for diffusion and vqbet (#3202)
* feat: add pretrained vision encoder weights for diffusion and vqbet * fix test by re-generating artifacts --------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Inputs / output structure.
|
# Inputs / output structure.
|
||||||
n_obs_steps: int = 2
|
n_obs_steps: int = 2
|
||||||
horizon: int = 16
|
horizon: int = 64
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 32
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
crop_ratio: float = 1.0
|
crop_ratio: float = 1.0
|
||||||
crop_shape: tuple[int, int] | None = None
|
crop_shape: tuple[int, int] | None = None
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
pretrained_backbone_weights: str | None = None
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = False
|
||||||
spatial_softmax_num_keypoints: int = 32
|
spatial_softmax_num_keypoints: int = 32
|
||||||
use_separate_rgb_encoder_per_camera: bool = False
|
use_separate_rgb_encoder_per_camera: bool = True
|
||||||
# Unet.
|
# Unet.
|
||||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||||
kernel_size: int = 5
|
kernel_size: int = 5
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
|
|||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
crop_shape: tuple[int, int] | None = (84, 84)
|
crop_shape: tuple[int, int] | None = (84, 84)
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
pretrained_backbone_weights: str | None = None
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = False
|
||||||
spatial_softmax_num_keypoints: int = 32
|
spatial_softmax_num_keypoints: int = 32
|
||||||
# VQ-VAE
|
# VQ-VAE
|
||||||
n_vqvae_training_steps: int = 20000
|
n_vqvae_training_steps: int = 20000
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
|
oid sha256:51effd76b73e972f10d31f5084ab906386134b600c87b2668767d30232a902bd
|
||||||
size 992
|
size 992
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
|
oid sha256:d4d7a16ca67f9adefac0e0620a7b2e9c822f2db42faaaced7a89fbad60e5ead4
|
||||||
size 47424
|
size 47680
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
|
oid sha256:796c439ee8a64bf9901ff8325e7419bda8bd316360ee95e6304e8e1ae0f4c36c
|
||||||
size 68
|
size 68
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
|
oid sha256:ad33a8b47c39c2e1374567ff9da43cdb95e2dbe904c1b02a35051346d3043095
|
||||||
size 47424
|
size 47680
|
||||||
|
|||||||
Reference in New Issue
Block a user