From e9c171ead00b4168e755d7033cd8aadf381228b7 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Mon, 18 May 2026 12:55:55 +0200 Subject: [PATCH] make default params more aligned with paper and pretrained models - adding possibility of freezing qwen backbone and world model - added tests for weight loading --- docs/source/policy_vla_jepa_README.md | 48 +++---- .../vla_jepa/configuration_vla_jepa.py | 32 ++--- .../policies/vla_jepa/modeling_vla_jepa.py | 3 + tests/policies/vla_jepa/test_vla_jepa.py | 118 +++++++++++++----- 4 files changed, 128 insertions(+), 73 deletions(-) diff --git a/docs/source/policy_vla_jepa_README.md b/docs/source/policy_vla_jepa_README.md index f704440a7..977b68879 100644 --- a/docs/source/policy_vla_jepa_README.md +++ b/docs/source/policy_vla_jepa_README.md @@ -82,26 +82,15 @@ policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO") Key parameters in `VLAJEPAConfig`: -| Parameter | Default | Description | -| -------------------------------------------- | ---------------------------------- | ------------------------------------------------------------------- | -| `qwen_model_name` | `"Qwen/Qwen3-VL-2B-Instruct"` | Qwen3-VL backbone variant | -| `jepa_encoder_name` | `"facebook/vjepa2-vitl-fpc64-256"` | V-JEPA2 video encoder | -| `chunk_size` | 16 | Number of actions predicted per inference call | -| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning | -| `num_video_frames` | 16 | Video clip length fed to the world model | -| `jepa_tubelet_size` | 2 | Temporal patch size of the video encoder (must match encoder) | -| `action_model_type` | `"DiT-B"` | DiT preset — controls hidden dim, heads, head dim | -| `action_hidden_size` | 1024 | DiT output projection size (and action decoder input size) | -| `action_num_layers` | 12 | Number of DiT transformer blocks | -| `num_target_vision_tokens` | 32 | Learned future-vision query tokens prepended to the action sequence | -| `action_max_seq_len` | 1024 | Max length of the positional embedding table in the action head | -| `num_action_tokens_per_timestep` | 4 | Special action tokens per temporal step (used for WM conditioning) | -| `num_embodied_action_tokens_per_instruction` | 8 | Instruction-level embodied tokens appended to the Qwen sequence | -| `num_inference_timesteps` | 10 | Euler integration steps for action denoising | -| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor | -| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss | -| `predictor_depth` | 6 | Number of transformer blocks in the video predictor | -| `repeated_diffusion_steps` | 4 | Independent noise draws per batch item (CogACT-style augmentation) | +| Parameter | Default | Description | +| ------------------------- | ------- | -------------------------------------------------------------- | +| `chunk_size` | 16 | Number of actions predicted per inference call | +| `n_action_steps` | 16 | Steps executed from the predicted chunk before re-planning | +| `num_video_frames` | 16 | Video clip length fed to the world model | +| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor | +| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss | +| `num_inference_timesteps` | 10 | Euler integration steps for action denoising | +| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head | --- @@ -120,10 +109,19 @@ lerobot-train \ ```bash lerobot-train \ - policy.path=lerobot/VLA-JEPA-LIBERO \ + policy.path=lerobot/VLA-JEPA-Pretrain \ dataset.repo_id=your_org/your_dataset ``` +If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`: + +```bash +lerobot-train \ + policy.path=lerobot/VLA-JEPA-Pretrain \ + dataset.repo_id=your_org/your_dataset \ + policy.freeze_qwen=true +``` + ### Reproducing the LIBERO results **Training on LIBERO:** @@ -138,14 +136,6 @@ TODO(Maxime): lerobot-train \ policy.path=lerobot/VLA-JEPA-Pretrain \ dataset.repo_id=lerobot/libero_10 \ - policy.chunk_size=7 \ - policy.n_action_steps=7 \ - policy.future_action_window_size=6 \ - policy.num_video_frames=8 \ - policy.num_action_tokens_per_timestep=8 \ - policy.num_embodied_action_tokens_per_instruction=32 \ - policy.action_num_layers=16 \ - policy.predictor_depth=12 \ training.num_steps=50000 \ env.type=libero \ env.task=libero_10 diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index b23594101..d2092fe83 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -12,8 +12,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @dataclass class VLAJEPAConfig(PreTrainedConfig): n_obs_steps: int = 1 - chunk_size: int = 16 - n_action_steps: int = 16 + chunk_size: int = 7 + n_action_steps: int = 7 normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { @@ -25,26 +25,26 @@ class VLAJEPAConfig(PreTrainedConfig): qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct" jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" + freeze_qwen: bool = False + enable_world_model: bool = True tokenizer_padding_side: str = "left" - prompt_template: str = ( - "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}." - ) + prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}." special_action_token: str = "<|action_{}|>" embodied_action_token: str = "<|embodied_action|>" action_dim: int = 7 state_dim: int = 8 - future_action_window_size: int = 15 + future_action_window_size: int = 6 past_action_window_size: int = 0 - num_action_tokens_per_timestep: int = 4 - num_embodied_action_tokens_per_instruction: int = 8 - num_inference_timesteps: int = 10 + num_action_tokens_per_timestep: int = 8 + num_embodied_action_tokens_per_instruction: int = 32 + num_inference_timesteps: int = 4 action_hidden_size: int = 1024 action_model_type: str = "DiT-B" - action_num_layers: int = 12 - action_dropout: float = 0.1 + action_num_layers: int = 16 + action_dropout: float = 0.2 action_num_timestep_buckets: int = 1000 action_noise_beta_alpha: float = 1.5 action_noise_beta_beta: float = 1.0 @@ -53,15 +53,14 @@ class VLAJEPAConfig(PreTrainedConfig): action_max_seq_len: int = 1024 # total video frames loaded per sample - num_video_frames: int = 16 - predictor_depth: int = 6 + num_video_frames: int = 8 + predictor_depth: int = 12 predictor_num_heads: int = 8 predictor_mlp_ratio: float = 4.0 predictor_dropout: float = 0.0 world_model_loss_weight: float = 0.1 - enable_world_model: bool = True jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256) - repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style) + repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style) resize_images_to: tuple[int, int] | None = None torch_dtype: str = "bfloat16" @@ -77,6 +76,9 @@ class VLAJEPAConfig(PreTrainedConfig): def __post_init__(self) -> None: super().__post_init__() + if self.freeze_qwen and self.enable_world_model: + # freezing qwen backbone makes world model training irrelevant since no grad flows + self.enable_world_model = False if self.n_action_steps > self.chunk_size: raise ValueError("`n_action_steps` must be <= `chunk_size`.") if self.future_action_window_size + 1 > self.chunk_size: diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 448d34df8..7d6590bfd 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -86,6 +86,9 @@ class VLAJEPAModel(nn.Module): self.video_processor = None self.video_predictor = None + if config.freeze_qwen: + self.qwen.requires_grad_(False) + # Build prompt placeholders. # Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor. # This matches the number of context temporal positions after tubelet compression. diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index 548fa236f..3b6e4a1a6 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -36,6 +36,11 @@ from lerobot.utils.constants import ACTION # noqa: E402 PRETRAINED_REPO_ID = "ginwind/VLA-JEPA" PRETRAINED_SUBFOLDER = "LIBERO" +# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are +# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in. +_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0" +extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests") + # --------------------------------------------------------------------------- # Core training / inference tests @@ -259,39 +264,94 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> # --------------------------------------------------------------------------- # Pretrained checkpoint +# Hub tests (opt-in: VLA_JEPA_EXTENDED=1) # --------------------------------------------------------------------------- -def test_pretrained_checkpoint_loads_from_hf_cache() -> None: - import torch - from huggingface_hub import hf_hub_download - from huggingface_hub.errors import LocalEntryNotFoundError +def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict: + """Build a training batch whose keys/shapes match a hub-loaded policy config.""" + cfg = policy.config + batch: dict = {"task": ["pick up the cube"] * batch_size} + for key, feat in cfg.image_features.items(): + h, w = feat.shape[-2], feat.shape[-1] + batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w) + if cfg.robot_state_feature is not None: + batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0]) + batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim) + return batch - repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID) - subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/") - checkpoint_filename = os.environ.get( - "VLA_JEPA_PRETRAINED_CHECKPOINT", - f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt", - ) - try: - checkpoint_path = hf_hub_download( - repo_id=repo_id, filename=checkpoint_filename, local_files_only=True - ) - except LocalEntryNotFoundError: - pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.") +def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict: + """Build an inference batch whose keys/shapes match a hub-loaded policy config.""" + cfg = policy.config + batch: dict = {"task": ["pick up the cube"] * batch_size} + for key, feat in cfg.image_features.items(): + h, w = feat.shape[-2], feat.shape[-1] + batch[key] = torch.rand(batch_size, 3, h, w) + if cfg.robot_state_feature is not None: + batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0]) + return batch - try: - checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False) - except TypeError: - checkpoint = torch.load(checkpoint_path, map_location="cpu") - state_dict = ( - checkpoint.get("state_dict") - or checkpoint.get("model_state_dict") - or checkpoint.get("model") - or checkpoint - ) - assert isinstance(state_dict, dict) - assert len(state_dict) > 0 - assert all(isinstance(k, str) for k in list(state_dict)[:10]) +_CP_ROOT = "lerobot" # TODO: upload converted checkpoints + +# Each tuple: (repo_id, enable_world_model) +_HUB_VARIANTS = [ + (f"{_CP_ROOT}/VLA-JEPA-LIBERO", True), + (f"{_CP_ROOT}/VLA-JEPA-Pretrain", True), + (f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False), +] + + +@extended_test +@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS) +def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None: + """Policy loads from the converted safetensors checkpoint on the Hub.""" + policy = VLAJEPAPolicy.from_pretrained(repo_id) + assert policy.config.enable_world_model == enable_world_model + assert sum(p.numel() for p in policy.parameters()) > 0 + + +@extended_test +@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS) +def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None: + """Policy loaded from hub produces finite losses with a correctly-shaped batch.""" + policy = VLAJEPAPolicy.from_pretrained(repo_id) + policy.train() + + batch = _make_hub_train_batch(policy) + loss, logs = policy.forward(batch) + assert torch.isfinite(loss) + assert "action_loss" in logs + if enable_world_model: + assert "wm_loss" in logs + + +@extended_test +def test_hub_freeze_qwen_disables_world_model() -> None: + """freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model.""" + policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"]) + assert not policy.config.enable_world_model + assert policy.model.video_predictor is None + qwen_params = list(policy.model.qwen.parameters()) + assert all(not p.requires_grad for p in qwen_params) + assert any(p.requires_grad for p in policy.model.action_model.parameters()) + + +@extended_test +def test_hub_disable_world_model_loads_simpler_env() -> None: + """SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference.""" + policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv") + assert not policy.config.enable_world_model + assert policy.model.video_predictor is None + assert policy.model.video_encoder is None + + +@extended_test +def test_hub_libero_inference_shape() -> None: + """select_action returns the expected shape using the LIBERO hub checkpoint.""" + policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO") + policy.eval() + batch = _make_hub_inference_batch(policy) + action = policy.select_action(batch) + assert action.shape[-1] == policy.config.action_dim