Files
lerobot/tests/policies/vla_jepa/test_action_head.py
T
2026-05-13 15:55:04 +02:00

120 lines
4.1 KiB
Python

#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers")
from conftest import (
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
set_seed_all,
) # noqa: E402
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
VLAJEPAActionHead,
)
# ---------------------------------------------------------------------------
# VLAJEPAActionHead
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),  # default test dims
        (7, 0, 16),  # no proprioceptive state, production-like action space
        (6, 8, 8),  # medium dims
    ],
)
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """Check that ``sample_time`` returns one finite value per requested sample.

    Args:
        action_dim: Action dimensionality passed to ``make_config``.
        state_dim: State dimensionality passed to ``make_config``.
        action_horizon: Action horizon passed to ``make_config``.
    """
    # Seed the RNGs like the sibling tests do, so a failing draw is reproducible.
    set_seed_all(42)
    num_samples = 200

    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)

    t = head.sample_time(batch_size=num_samples, device=torch.device("cpu"), dtype=torch.float32)

    assert t.shape == (num_samples,)
    assert torch.isfinite(t).all()
    # NOTE(review): despite the test name, no lower/upper bound is asserted here.
    # The valid interval of ``sample_time`` is not visible from this file -- confirm
    # it against the action head implementation and add explicit bound checks.
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """Check the rank and finiteness of ``_build_inputs`` with and without a state.

    The state-free call is exercised for every parametrization. The call with
    a state tensor only runs when ``state_dim > 0``, where its output must be
    longer along dimension 1 than the state-free output.

    Args:
        action_dim: Size of the last dimension of the action tensor.
        state_dim: Size of the state vector; ``0`` means no state is supplied.
        action_horizon: Size of dimension 1 of the action tensor.
    """
    # Seed the RNGs like the sibling tests do, so a failing draw is reproducible.
    set_seed_all(42)
    batch = 2

    config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)

    conditioning = torch.randn(batch, 4, QWEN_HIDDEN_SIZE)
    actions = torch.randn(batch, action_horizon, action_dim)
    timesteps = torch.randint(0, 100, (batch,))

    out_none = head._build_inputs(conditioning, actions, None, timesteps)
    # Separate assertions so a failure pinpoints the violated property.
    assert out_none.ndim == 3
    assert torch.isfinite(out_none).all()

    if state_dim <= 0:
        # Without a state tensor the "with state" call would be identical to
        # the state-free call above, so there is nothing more to compare.
        return

    state = torch.randn(batch, state_dim)
    out_with = head._build_inputs(conditioning, actions, state, timesteps)
    assert out_with.ndim == 3
    assert torch.isfinite(out_with).all()
    # Supplying a state must lengthen the output along dimension 1.
    assert out_with.shape[1] > out_none.shape[1]
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """Check that ``forward`` yields a finite, strictly positive scalar loss.

    Args:
        action_dim: Size of the last dimension of the action tensor.
        state_dim: Size of the state vector; ``0`` means ``None`` is passed as state.
        action_horizon: Size of dimension 1 of the action tensor.
    """
    set_seed_all(42)
    n_batch = 2

    cfg = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    action_head = VLAJEPAActionHead(cfg, cross_attention_dim=QWEN_HIDDEN_SIZE)

    cond = torch.randn(n_batch, 4, QWEN_HIDDEN_SIZE)
    target_actions = torch.randn(n_batch, action_horizon, action_dim)
    # A zero-width state is represented by passing no state at all.
    proprio = None if state_dim <= 0 else torch.randn(n_batch, state_dim)

    loss = action_head.forward(cond, target_actions, proprio)

    assert loss.shape == ()
    assert torch.isfinite(loss)
    assert loss > 0
def test_action_head_forward_gradient_flows() -> None:
    """Check that backpropagating the ``forward`` loss populates a gradient.

    Passes when at least one parameter with ``requires_grad`` set has a
    non-``None`` ``grad`` after ``backward()``.
    """
    set_seed_all(42)
    action_head = VLAJEPAActionHead(make_config(), cross_attention_dim=QWEN_HIDDEN_SIZE)

    cond = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    target_actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    proprio = torch.randn(BATCH_SIZE, STATE_DIM)

    action_head.forward(cond, target_actions, proprio).backward()

    trainable_with_grad = [
        param for param in action_head.parameters() if param.requires_grad and param.grad is not None
    ]
    assert trainable_with_grad
@torch.no_grad()
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """Check that ``predict_action`` returns finite values of the expected shape.

    The expected shape is ``(batch, action_horizon, action_dim)``. The whole
    test runs with gradient tracking disabled via the ``torch.no_grad()`` decorator.

    Args:
        action_dim: Expected size of the prediction's last dimension.
        state_dim: Size of the state vector; ``0`` means ``None`` is passed as state.
        action_horizon: Expected size of the prediction's dimension 1.
    """
    set_seed_all(42)
    n_batch = 2

    cfg = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    action_head = VLAJEPAActionHead(cfg, cross_attention_dim=QWEN_HIDDEN_SIZE)

    cond = torch.randn(n_batch, 4, QWEN_HIDDEN_SIZE)
    # A zero-width state is represented by passing no state at all.
    proprio = None if state_dim <= 0 else torch.randn(n_batch, state_dim)

    predicted = action_head.predict_action(cond, proprio)

    expected_shape = (n_batch, action_horizon, action_dim)
    assert tuple(predicted.shape) == expected_shape
    assert torch.isfinite(predicted).all()