make default params more aligned with paper and pretrained models

- adding possibility of freezing qwen backbone and world model
- added tests for weight loading
This commit is contained in:
Maximellerbach
2026-05-18 12:55:55 +02:00
parent cee80daa88
commit e9c171ead0
4 changed files with 128 additions and 73 deletions
+19 -29
View File
@@ -82,26 +82,15 @@ policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO")
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| -------------------------------------------- | ---------------------------------- | ------------------------------------------------------------------- |
| `qwen_model_name` | `"Qwen/Qwen3-VL-2B-Instruct"` | Qwen3-VL backbone variant |
| `jepa_encoder_name` | `"facebook/vjepa2-vitl-fpc64-256"` | V-JEPA2 video encoder |
| `chunk_size` | 16 | Number of actions predicted per inference call |
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 16 | Video clip length fed to the world model |
| `jepa_tubelet_size` | 2 | Temporal patch size of the video encoder (must match encoder) |
| `action_model_type` | `"DiT-B"` | DiT preset — controls hidden dim, heads, head dim |
| `action_hidden_size` | 1024 | DiT output projection size (and action decoder input size) |
| `action_num_layers` | 12 | Number of DiT transformer blocks |
| `num_target_vision_tokens` | 32 | Learned future-vision query tokens prepended to the action sequence |
| `action_max_seq_len` | 1024 | Max length of the positional embedding table in the action head |
| `num_action_tokens_per_timestep` | 4 | Special action tokens per temporal step (used for WM conditioning) |
| `num_embodied_action_tokens_per_instruction` | 8 | Instruction-level embodied tokens appended to the Qwen sequence |
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `predictor_depth` | 6 | Number of transformer blocks in the video predictor |
| `repeated_diffusion_steps` | 4 | Independent noise draws per batch item (CogACT-style augmentation) |
| Parameter | Default | Description |
| ------------------------- | ------- | -------------------------------------------------------------- |
| `chunk_size` | 16 | Number of actions predicted per inference call |
| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 16 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 10 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
---
@@ -120,10 +109,19 @@ lerobot-train \
```bash
lerobot-train \
policy.path=lerobot/VLA-JEPA-LIBERO \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=your_org/your_dataset
```
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=your_org/your_dataset \
policy.freeze_qwen=true
```
### Reproducing the LIBERO results
**Training on LIBERO:**
@@ -138,14 +136,6 @@ TODO(Maxime):
lerobot-train \
policy.path=lerobot/VLA-JEPA-Pretrain \
dataset.repo_id=lerobot/libero_10 \
policy.chunk_size=7 \
policy.n_action_steps=7 \
policy.future_action_window_size=6 \
policy.num_video_frames=8 \
policy.num_action_tokens_per_timestep=8 \
policy.num_embodied_action_tokens_per_instruction=32 \
policy.action_num_layers=16 \
policy.predictor_depth=12 \
training.num_steps=50000 \
env.type=libero \
env.task=libero_10
@@ -12,8 +12,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 16
n_action_steps: int = 16
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
@@ -25,26 +25,26 @@ class VLAJEPAConfig(PreTrainedConfig):
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
tokenizer_padding_side: str = "left"
prompt_template: str = (
"{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
)
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
future_action_window_size: int = 15
future_action_window_size: int = 6
past_action_window_size: int = 0
num_action_tokens_per_timestep: int = 4
num_embodied_action_tokens_per_instruction: int = 8
num_inference_timesteps: int = 10
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 12
action_dropout: float = 0.1
action_num_layers: int = 16
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
@@ -53,15 +53,14 @@ class VLAJEPAConfig(PreTrainedConfig):
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 16
predictor_depth: int = 6
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
enable_world_model: bool = True
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
torch_dtype: str = "bfloat16"
@@ -77,6 +76,9 @@ class VLAJEPAConfig(PreTrainedConfig):
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.future_action_window_size + 1 > self.chunk_size:
@@ -86,6 +86,9 @@ class VLAJEPAModel(nn.Module):
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
# This matches the number of context temporal positions after tubelet compression.
+89 -29
View File
@@ -36,6 +36,11 @@ from lerobot.utils.constants import ACTION # noqa: E402
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO"
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
# ---------------------------------------------------------------------------
# Core training / inference tests
@@ -259,39 +264,94 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) ->
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
# ---------------------------------------------------------------------------
def test_pretrained_checkpoint_loads_from_hf_cache() -> None:
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import LocalEntryNotFoundError
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
cfg = policy.config
batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
return batch
repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
checkpoint_filename = os.environ.get(
"VLA_JEPA_PRETRAINED_CHECKPOINT",
f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt",
)
try:
checkpoint_path = hf_hub_download(
repo_id=repo_id, filename=checkpoint_filename, local_files_only=True
)
except LocalEntryNotFoundError:
pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.")
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
cfg = policy.config
batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
return batch
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
except TypeError:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = (
checkpoint.get("state_dict")
or checkpoint.get("model_state_dict")
or checkpoint.get("model")
or checkpoint
)
assert isinstance(state_dict, dict)
assert len(state_dict) > 0
assert all(isinstance(k, str) for k in list(state_dict)[:10])
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
# Each tuple: (repo_id, enable_world_model)
_HUB_VARIANTS = [
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
]
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
"""Policy loads from the converted safetensors checkpoint on the Hub."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
assert policy.config.enable_world_model == enable_world_model
assert sum(p.numel() for p in policy.parameters()) > 0
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
policy.train()
batch = _make_hub_train_batch(policy)
loss, logs = policy.forward(batch)
assert torch.isfinite(loss)
assert "action_loss" in logs
if enable_world_model:
assert "wm_loss" in logs
@extended_test
def test_hub_freeze_qwen_disables_world_model() -> None:
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
qwen_params = list(policy.model.qwen.parameters())
assert all(not p.requires_grad for p in qwen_params)
assert any(p.requires_grad for p in policy.model.action_model.parameters())
@extended_test
def test_hub_disable_world_model_loads_simpler_env() -> None:
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
assert policy.model.video_encoder is None
@extended_test
def test_hub_libero_inference_shape() -> None:
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
policy.eval()
batch = _make_hub_inference_batch(policy)
action = policy.select_action(batch)
assert action.shape[-1] == policy.config.action_dim