diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py index 2e7d047cd..a1e9f9960 100644 --- a/tests/policies/vla_jepa/conftest.py +++ b/tests/policies/vla_jepa/conftest.py @@ -70,9 +70,8 @@ def make_config( num_embodied_action_tokens_per_instruction=3, num_inference_timesteps=2, action_hidden_size=QWEN_HIDDEN_SIZE, + action_model_type="DiT-test", action_num_layers=1, - action_num_heads=2, - action_attention_head_dim=8, predictor_depth=1, predictor_num_heads=2, predictor_mlp_ratio=2.0, diff --git a/tests/policies/vla_jepa/test_action_head.py b/tests/policies/vla_jepa/test_action_head.py index eb2d3168d..5acff6371 100644 --- a/tests/policies/vla_jepa/test_action_head.py +++ b/tests/policies/vla_jepa/test_action_head.py @@ -117,3 +117,41 @@ def test_action_head_predict_action_shape(action_dim: int, state_dim: int, actio pred = head.predict_action(conditioning, state) assert tuple(pred.shape) == (2, action_horizon, action_dim) assert torch.isfinite(pred).all() + + +# --------------------------------------------------------------------------- +# action_is_pad masking +# --------------------------------------------------------------------------- + + +def test_action_head_loss_fully_padded_is_zero() -> None: + """Loss is 0 when every timestep is padded (exercises the clamp_min guard).""" + set_seed_all(42) + config = make_config() + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + state = torch.randn(BATCH_SIZE, STATE_DIM) + + action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool) + loss = head.forward(conditioning, actions, state, action_is_pad) + assert loss.item() == 0.0 + + +def test_action_head_loss_none_matches_no_padding() -> None: + """action_is_pad=None is equivalent to an all-False (no padding) mask.""" + set_seed_all(42) + config = make_config() + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + state = torch.randn(BATCH_SIZE, STATE_DIM) + + set_seed_all(0) + loss_none = head.forward(conditioning, actions, state, action_is_pad=None) + + set_seed_all(0) + no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool) + loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad) + + assert torch.isclose(loss_none, loss_zeros) diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index ae51126de..548fa236f 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -257,13 +257,6 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5) -def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None: - policy = VLAJEPAPolicy(make_config()) - _, logs = policy._native_to_lerobot({"wm_loss": torch.tensor(0.3)}) - assert "action_loss" not in logs - assert logs["wm_loss"] == pytest.approx(0.3, abs=1e-5) - - # --------------------------------------------------------------------------- # Pretrained checkpoint # ---------------------------------------------------------------------------