make default params more aligned with paper and pretrained models

- adding possibility of freezing qwen backbone and world model
- added tests for weight loading
This commit is contained in:
Maximellerbach
2026-05-18 12:55:55 +02:00
parent cee80daa88
commit e9c171ead0
4 changed files with 128 additions and 73 deletions
+19 -29
View File
@@ -82,26 +82,15 @@ policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO")
Key parameters in `VLAJEPAConfig`: Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description | | Parameter | Default | Description |
| -------------------------------------------- | ---------------------------------- | ------------------------------------------------------------------- | | ------------------------- | ------- | -------------------------------------------------------------- |
| `qwen_model_name` | `"Qwen/Qwen3-VL-2B-Instruct"` | Qwen3-VL backbone variant | | `chunk_size` | 16 | Number of actions predicted per inference call |
| `jepa_encoder_name` | `"facebook/vjepa2-vitl-fpc64-256"` | V-JEPA2 video encoder | | `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
| `chunk_size` | 16 | Number of actions predicted per inference call | | `num_video_frames` | 16 | Video clip length fed to the world model |
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning | | `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `num_video_frames` | 16 | Video clip length fed to the world model | | `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `jepa_tubelet_size` | 2 | Temporal patch size of the video encoder (must match encoder) | | `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
| `action_model_type` | `"DiT-B"` | DiT preset — controls hidden dim, heads, head dim | | `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `action_hidden_size` | 1024 | DiT output projection size (and action decoder input size) |
| `action_num_layers` | 12 | Number of DiT transformer blocks |
| `num_target_vision_tokens` | 32 | Learned future-vision query tokens prepended to the action sequence |
| `action_max_seq_len` | 1024 | Max length of the positional embedding table in the action head |
| `num_action_tokens_per_timestep` | 4 | Special action tokens per temporal step (used for WM conditioning) |
| `num_embodied_action_tokens_per_instruction` | 8 | Instruction-level embodied tokens appended to the Qwen sequence |
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `predictor_depth` | 6 | Number of transformer blocks in the video predictor |
| `repeated_diffusion_steps` | 4 | Independent noise draws per batch item (CogACT-style augmentation) |
--- ---
@@ -120,10 +109,19 @@ lerobot-train \
```bash ```bash
lerobot-train \ lerobot-train \
policy.path=lerobot/VLA-JEPA-LIBERO \ policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=your_org/your_dataset dataset.repo_id=your_org/your_dataset
``` ```
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=your_org/your_dataset \
policy.freeze_qwen=true
```
### Reproducing the LIBERO results ### Reproducing the LIBERO results
**Training on LIBERO:** **Training on LIBERO:**
@@ -138,14 +136,6 @@ TODO(Maxime):
lerobot-train \ lerobot-train \
policy.path=lerobot/VLA-JEPA-Pretrain \ policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=lerobot/libero_10 \ dataset.repo_id=lerobot/libero_10 \
policy.chunk_size=7 \
policy.n_action_steps=7 \
policy.future_action_window_size=6 \
policy.num_video_frames=8 \
policy.num_action_tokens_per_timestep=8 \
policy.num_embodied_action_tokens_per_instruction=32 \
policy.action_num_layers=16 \
policy.predictor_depth=12 \
training.num_steps=50000 \ training.num_steps=50000 \
env.type=libero \ env.type=libero \
env.task=libero_10 env.task=libero_10
@@ -12,8 +12,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@dataclass @dataclass
class VLAJEPAConfig(PreTrainedConfig): class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1 n_obs_steps: int = 1
chunk_size: int = 16 chunk_size: int = 7
n_action_steps: int = 16 n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
@@ -25,26 +25,26 @@ class VLAJEPAConfig(PreTrainedConfig):
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct" qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
tokenizer_padding_side: str = "left" tokenizer_padding_side: str = "left"
prompt_template: str = ( prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
"{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
)
special_action_token: str = "<|action_{}|>" special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>" embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7 action_dim: int = 7
state_dim: int = 8 state_dim: int = 8
future_action_window_size: int = 15 future_action_window_size: int = 6
past_action_window_size: int = 0 past_action_window_size: int = 0
num_action_tokens_per_timestep: int = 4 num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 8 num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 10 num_inference_timesteps: int = 4
action_hidden_size: int = 1024 action_hidden_size: int = 1024
action_model_type: str = "DiT-B" action_model_type: str = "DiT-B"
action_num_layers: int = 12 action_num_layers: int = 16
action_dropout: float = 0.1 action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000 action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5 action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0 action_noise_beta_beta: float = 1.0
@@ -53,15 +53,14 @@ class VLAJEPAConfig(PreTrainedConfig):
action_max_seq_len: int = 1024 action_max_seq_len: int = 1024
# total video frames loaded per sample # total video frames loaded per sample
num_video_frames: int = 16 num_video_frames: int = 8
predictor_depth: int = 6 predictor_depth: int = 12
predictor_num_heads: int = 8 predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0 predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0 predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1 world_model_loss_weight: float = 0.1
enable_world_model: bool = True
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256) jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style) repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None resize_images_to: tuple[int, int] | None = None
torch_dtype: str = "bfloat16" torch_dtype: str = "bfloat16"
@@ -77,6 +76,9 @@ class VLAJEPAConfig(PreTrainedConfig):
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size: if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.") raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.future_action_window_size + 1 > self.chunk_size: if self.future_action_window_size + 1 > self.chunk_size:
@@ -86,6 +86,9 @@ class VLAJEPAModel(nn.Module):
self.video_processor = None self.video_processor = None
self.video_predictor = None self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders. # Build prompt placeholders.
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor. # Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
# This matches the number of context temporal positions after tubelet compression. # This matches the number of context temporal positions after tubelet compression.
+89 -29
View File
@@ -36,6 +36,11 @@ from lerobot.utils.constants import ACTION # noqa: E402
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA" PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO" PRETRAINED_SUBFOLDER = "LIBERO"
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Core training / inference tests # Core training / inference tests
@@ -259,39 +264,94 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) ->
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Pretrained checkpoint # Pretrained checkpoint
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_pretrained_checkpoint_loads_from_hf_cache() -> None: def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
import torch """Build a training batch whose keys/shapes match a hub-loaded policy config."""
from huggingface_hub import hf_hub_download cfg = policy.config
from huggingface_hub.errors import LocalEntryNotFoundError batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
return batch
repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
checkpoint_filename = os.environ.get(
"VLA_JEPA_PRETRAINED_CHECKPOINT",
f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt",
)
try: def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
checkpoint_path = hf_hub_download( """Build an inference batch whose keys/shapes match a hub-loaded policy config."""
repo_id=repo_id, filename=checkpoint_filename, local_files_only=True cfg = policy.config
) batch: dict = {"task": ["pick up the cube"] * batch_size}
except LocalEntryNotFoundError: for key, feat in cfg.image_features.items():
pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.") h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
return batch
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
except TypeError:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = ( _CP_ROOT = "lerobot" # TODO: upload converted checkpoints
checkpoint.get("state_dict")
or checkpoint.get("model_state_dict") # Each tuple: (repo_id, enable_world_model)
or checkpoint.get("model") _HUB_VARIANTS = [
or checkpoint (f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
) (f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
assert isinstance(state_dict, dict) (f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
assert len(state_dict) > 0 ]
assert all(isinstance(k, str) for k in list(state_dict)[:10])
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
"""Policy loads from the converted safetensors checkpoint on the Hub."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
assert policy.config.enable_world_model == enable_world_model
assert sum(p.numel() for p in policy.parameters()) > 0
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
policy.train()
batch = _make_hub_train_batch(policy)
loss, logs = policy.forward(batch)
assert torch.isfinite(loss)
assert "action_loss" in logs
if enable_world_model:
assert "wm_loss" in logs
@extended_test
def test_hub_freeze_qwen_disables_world_model() -> None:
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
qwen_params = list(policy.model.qwen.parameters())
assert all(not p.requires_grad for p in qwen_params)
assert any(p.requires_grad for p in policy.model.action_model.parameters())
@extended_test
def test_hub_disable_world_model_loads_simpler_env() -> None:
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
assert policy.model.video_encoder is None
@extended_test
def test_hub_libero_inference_shape() -> None:
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
policy.eval()
batch = _make_hub_inference_batch(policy)
action = policy.select_action(batch)
assert action.shape[-1] == policy.config.action_dim