update VLA-JEPA tests for arch changes and action_is_pad

- Switch conftest to use `action_model_type="DiT-test"` now that
  `action_num_heads` / `action_attention_head_dim` have been removed.
- Add action_head tests covering fully-padded loss (zero) and equivalence
  of action_is_pad=None vs all-zeros mask.
- Remove obsolete `test_native_to_lerobot_wm_only` test.
This commit is contained in:
Maximellerbach
2026-05-15 14:34:41 +02:00
parent ab5222c819
commit 1f80d91607
3 changed files with 39 additions and 9 deletions
+1 -2
View File
@@ -70,9 +70,8 @@ def make_config(
num_embodied_action_tokens_per_instruction=3,
num_inference_timesteps=2,
action_hidden_size=QWEN_HIDDEN_SIZE,
action_model_type="DiT-test",
action_num_layers=1,
action_num_heads=2,
action_attention_head_dim=8,
predictor_depth=1,
predictor_num_heads=2,
predictor_mlp_ratio=2.0,
@@ -117,3 +117,41 @@ def test_action_head_predict_action_shape(action_dim: int, state_dim: int, actio
pred = head.predict_action(conditioning, state)
assert tuple(pred.shape) == (2, action_horizon, action_dim)
assert torch.isfinite(pred).all()
# ---------------------------------------------------------------------------
# action_is_pad masking
# ---------------------------------------------------------------------------
def test_action_head_loss_fully_padded_is_zero() -> None:
    """Check that a batch whose every action step is padded yields a loss of exactly 0.

    With an all-True ``action_is_pad`` mask no timestep contributes to the
    loss; this is meant to exercise the ``clamp_min`` guard in the head.
    """
    set_seed_all(42)

    # Build the head first so parameter initialisation draws from the RNG
    # before any of the random inputs below are sampled.
    action_head = VLAJEPAActionHead(make_config(), cross_attention_dim=QWEN_HIDDEN_SIZE)

    cond_tokens = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    target_actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    robot_state = torch.randn(BATCH_SIZE, STATE_DIM)

    # Every (batch, timestep) entry is flagged as padding.
    all_padded = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)

    loss_value = head_loss = action_head.forward(cond_tokens, target_actions, robot_state, all_padded)
    loss_value = head_loss.item()
    assert loss_value == 0.0
def test_action_head_loss_none_matches_no_padding() -> None:
    """Check that omitting the mask gives the same loss as an all-False mask.

    Passing ``action_is_pad=None`` must be equivalent to passing a boolean
    mask in which no timestep is marked as padding.
    """
    set_seed_all(42)
    action_head = VLAJEPAActionHead(make_config(), cross_attention_dim=QWEN_HIDDEN_SIZE)

    cond_tokens = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    target_actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    robot_state = torch.randn(BATCH_SIZE, STATE_DIM)

    # Re-seed before each forward pass so both calls start from the same
    # RNG state and any randomness inside the head is identical.
    set_seed_all(0)
    loss_without_mask = action_head.forward(
        cond_tokens, target_actions, robot_state, action_is_pad=None
    )

    set_seed_all(0)
    nothing_padded = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
    loss_with_empty_mask = action_head.forward(
        cond_tokens, target_actions, robot_state, action_is_pad=nothing_padded
    )

    assert torch.isclose(loss_without_mask, loss_with_empty_mask)
-7
View File
@@ -257,13 +257,6 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) ->
assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5)
def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None:
    """Check the converted logs when the native output holds only ``wm_loss``.

    The logs must not contain an ``action_loss`` entry and must report the
    world-model loss value that was passed in.
    """
    native_losses = {"wm_loss": torch.tensor(0.3)}
    vla_policy = VLAJEPAPolicy(make_config())

    # Only the logs half of the returned pair is inspected here.
    _, log_dict = vla_policy._native_to_lerobot(native_losses)

    assert "action_loss" not in log_dict
    assert log_dict["wm_loss"] == pytest.approx(0.3, abs=1e-5)
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# ---------------------------------------------------------------------------