mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
update VLA-JEPA tests for arch changes and action_is_pad
- Switch conftest to use `action_model_type="DiT-test"` now that `action_num_heads` / `action_attention_head_dim` have been removed. - Add action_head tests covering fully-padded loss (zero) and equivalence of action_is_pad=None vs all-zeros mask. - Remove obsolete `test_native_to_lerobot_wm_only` test.
This commit is contained in:
@@ -70,9 +70,8 @@ def make_config(
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
action_num_heads=2,
|
||||
action_attention_head_dim=8,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
|
||||
@@ -117,3 +117,41 @@ def test_action_head_predict_action_shape(action_dim: int, state_dim: int, actio
|
||||
pred = head.predict_action(conditioning, state)
|
||||
assert tuple(pred.shape) == (2, action_horizon, action_dim)
|
||||
assert torch.isfinite(pred).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# action_is_pad masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_action_head_loss_fully_padded_is_zero() -> None:
|
||||
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss = head.forward(conditioning, actions, state, action_is_pad)
|
||||
assert loss.item() == 0.0
|
||||
|
||||
|
||||
def test_action_head_loss_none_matches_no_padding() -> None:
|
||||
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
set_seed_all(0)
|
||||
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
|
||||
|
||||
set_seed_all(0)
|
||||
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
|
||||
|
||||
assert torch.isclose(loss_none, loss_zeros)
|
||||
|
||||
@@ -257,13 +257,6 @@ def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) ->
|
||||
assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5)
|
||||
|
||||
|
||||
def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
_, logs = policy._native_to_lerobot({"wm_loss": torch.tensor(0.3)})
|
||||
assert "action_loss" not in logs
|
||||
assert logs["wm_loss"] == pytest.approx(0.3, abs=1e-5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pretrained checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user