adding more tests to ensure good coverage

This commit is contained in:
Maximellerbach
2026-05-13 15:55:04 +02:00
parent e28d34a3cf
commit ddaff399b5
5 changed files with 669 additions and 213 deletions
+232
View File
@@ -0,0 +1,232 @@
#!/usr/bin/env python
"""Shared fixtures and helpers for VLA-JEPA tests."""
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import pytest
import torch
from PIL import Image
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
# ---------------------------------------------------------------------------
# Shared constants
# ---------------------------------------------------------------------------
BATCH_SIZE = 2
ACTION_DIM = 3
STATE_DIM = 4
IMAGE_SIZE = 8
ACTION_HORIZON = 4
N_ACTION_STEPS = 2
NUM_VIDEO_FRAMES = 3
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def set_seed_all(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def make_config(
action_dim: int = ACTION_DIM,
state_dim: int = STATE_DIM,
action_horizon: int = ACTION_HORIZON,
num_video_frames: int = NUM_VIDEO_FRAMES,
) -> VLAJEPAConfig:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
},
output_features={
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
},
device="cpu",
chunk_size=action_horizon,
n_action_steps=min(N_ACTION_STEPS, action_horizon),
future_action_window_size=action_horizon - 1,
action_dim=action_dim,
state_dim=state_dim,
num_video_frames=num_video_frames,
num_action_tokens_per_timestep=2,
num_embodied_action_tokens_per_instruction=3,
num_inference_timesteps=2,
action_hidden_size=QWEN_HIDDEN_SIZE,
action_num_layers=1,
action_num_heads=2,
action_attention_head_dim=8,
predictor_depth=1,
predictor_num_heads=2,
predictor_mlp_ratio=2.0,
jepa_tubelet_size=1,
)
config.validate_features()
return config
def make_train_batch(
batch_size: int = BATCH_SIZE,
action_dim: int = ACTION_DIM,
state_dim: int = STATE_DIM,
action_horizon: int = ACTION_HORIZON,
num_video_frames: int = NUM_VIDEO_FRAMES,
) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, 1, state_dim),
ACTION: torch.randn(batch_size, action_horizon, action_dim),
"task": ["pick up the cube"] * batch_size,
}
def make_inference_batch(
batch_size: int = BATCH_SIZE,
state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, state_dim),
"task": ["pick up the cube"] * batch_size,
}
# ---------------------------------------------------------------------------
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
# ---------------------------------------------------------------------------
class _FakeQwenBackbone(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
@property
def device(self) -> torch.device:
return self.weight.device
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden_size = self.config.hidden_size
values = torch.arange(
batch_size * seq_len * hidden_size,
device=input_ids.device,
dtype=torch.float32,
).view(batch_size, seq_len, hidden_size)
hidden = values / values.numel() + self.weight
return SimpleNamespace(hidden_states=[hidden])
class _FakeQwenInterface(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
return torch.float32 if dtype_name == "float32" else torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000
def build_inputs(
self,
images: list[list[Image.Image]],
instructions: list[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, Tensor]:
batch_size = len(images)
del images, instructions, action_prompt, embodied_prompt
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
token_ids = (
[10]
+ list(range(1000, 1000 + action_count))
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
+ [11]
)
return {
"input_ids": torch.tensor(
[token_ids] * batch_size,
device=self.model.device,
dtype=torch.long,
)
}
@staticmethod
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
return Image.fromarray(image)
class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
@property
def device(self) -> torch.device:
return self.weight.device
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
batch_size, num_frames = pixel_values_videos.shape[:2]
hidden_size = self.config.hidden_size
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
class _FakeVideoProcessor:
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
from lerobot.policies.vla_jepa import modeling_vla_jepa
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
monkeypatch.setattr(
modeling_vla_jepa.AutoModel,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoEncoder(),
)
monkeypatch.setattr(
modeling_vla_jepa.AutoVideoProcessor,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoProcessor(),
)
+119
View File
@@ -0,0 +1,119 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers")
from conftest import (
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
set_seed_all,
) # noqa: E402
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
VLAJEPAActionHead,
)
# ---------------------------------------------------------------------------
# VLAJEPAActionHead
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4), # default test dims
(7, 0, 16), # no proprioceptive state, production-like action space
(6, 8, 8), # medium dims
],
)
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
assert t.shape == (200,)
assert torch.isfinite(t).all()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(2, action_horizon, action_dim)
timesteps = torch.randint(0, 100, (2,))
state = torch.randn(2, state_dim) if state_dim > 0 else None
out_with = head._build_inputs(conditioning, actions, state, timesteps)
out_none = head._build_inputs(conditioning, actions, None, timesteps)
assert out_with.ndim == 3 and out_none.ndim == 3
if state_dim > 0:
assert out_with.shape[1] > out_none.shape[1]
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(2, action_horizon, action_dim)
state = torch.randn(2, state_dim) if state_dim > 0 else None
loss = head.forward(conditioning, actions, state)
assert loss.shape == ()
assert torch.isfinite(loss) and loss > 0
def test_action_head_forward_gradient_flows() -> None:
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
loss = head.forward(conditioning, actions, state)
loss.backward()
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
@torch.no_grad()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
state = torch.randn(2, state_dim) if state_dim > 0 else None
pred = head.predict_action(conditioning, state)
assert tuple(pred.shape) == (2, action_horizon, action_dim)
assert torch.isfinite(pred).all()
@@ -0,0 +1,63 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def test_delta_indices() -> None:
config = make_config()
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
assert config.action_delta_indices == list(range(ACTION_HORIZON))
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
with pytest.raises(ValueError, match="n_action_steps"):
VLAJEPAConfig(chunk_size=4, n_action_steps=8, future_action_window_size=3)
def test_future_window_exceeds_chunk_size_raises() -> None:
with pytest.raises(ValueError, match="predicted action horizon"):
VLAJEPAConfig(chunk_size=4, n_action_steps=4, future_action_window_size=4)
def test_too_few_video_frames_raises() -> None:
with pytest.raises(ValueError, match="video_horizon"):
VLAJEPAConfig(
chunk_size=16,
n_action_steps=16,
future_action_window_size=15,
num_video_frames=2,
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
)
def test_validate_features_no_image_raises() -> None:
config = VLAJEPAConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
with pytest.raises(ValueError, match="at least one visual input feature"):
config.validate_features()
def test_validate_features_no_action_raises() -> None:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
},
output_features={},
)
with pytest.raises(ValueError, match="action output feature"):
config.validate_features()
def test_validate_features_sets_action_dim_from_feature() -> None:
config = make_config(action_dim=6, state_dim=10)
assert config.action_dim == 6
assert config.state_dim == 10
+202 -213
View File
@@ -4,18 +4,10 @@ from __future__ import annotations
import os import os
from copy import deepcopy from copy import deepcopy
from types import SimpleNamespace
import numpy as np
import pytest import pytest
import torch import torch
from PIL import Image from torch import Tensor
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
pytest.importorskip("transformers") pytest.importorskip("transformers")
pytest.importorskip("diffusers") pytest.importorskip("diffusers")
@@ -24,190 +16,33 @@ pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning" "ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
) )
from conftest import ( # noqa: E402
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
EXPECTED_ACTION_CHUNK_SHAPE,
EXPECTED_SELECT_ACTION_SHAPE,
N_ACTION_STEPS,
STATE_DIM,
make_config,
make_inference_batch,
make_train_batch,
set_seed_all,
)
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
from lerobot.utils.constants import ACTION # noqa: E402
BATCH_SIZE = 2
ACTION_DIM = 3
STATE_DIM = 4
IMAGE_SIZE = 8
ACTION_HORIZON = 4
N_ACTION_STEPS = 2
NUM_VIDEO_FRAMES = 3
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA" PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO" PRETRAINED_SUBFOLDER = "LIBERO"
def set_seed_all(seed: int) -> None: # ---------------------------------------------------------------------------
np.random.seed(seed) # Core training / inference tests
torch.manual_seed(seed) # ---------------------------------------------------------------------------
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
class _FakeQwenBackbone(nn.Module): def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
@property
def device(self) -> torch.device:
return self.weight.device
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden_size = self.config.hidden_size
values = torch.arange(
batch_size * seq_len * hidden_size,
device=input_ids.device,
dtype=torch.float32,
).view(batch_size, seq_len, hidden_size)
hidden = values / values.numel() + self.weight
return SimpleNamespace(hidden_states=[hidden])
class _FakeQwenInterface(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = _FakeQwenBackbone(hidden_size=16)
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
return torch.float32 if dtype_name == "float32" else torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000
def build_inputs(
self,
images: list[list[Image.Image]],
instructions: list[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, Tensor]:
batch_size = len(images)
del images, instructions, action_prompt, embodied_prompt
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
token_ids = (
[10]
+ list(range(1000, 1000 + action_count))
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
+ [11]
)
input_ids = torch.tensor(
[token_ids] * batch_size,
device=self.model.device,
dtype=torch.long,
)
return {"input_ids": input_ids}
@staticmethod
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
return Image.fromarray(image)
class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
@property
def device(self) -> torch.device:
return self.weight.device
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
batch_size, num_frames = pixel_values_videos.shape[:2]
hidden_size = self.config.hidden_size
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
class _FakeVideoProcessor:
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
from lerobot.policies.vla_jepa import modeling_vla_jepa
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
monkeypatch.setattr(
modeling_vla_jepa.AutoModel,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoEncoder(),
)
monkeypatch.setattr(
modeling_vla_jepa.AutoVideoProcessor,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoProcessor(),
)
def make_config() -> VLAJEPAConfig:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
},
output_features={
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
},
device="cpu",
chunk_size=ACTION_HORIZON,
n_action_steps=N_ACTION_STEPS,
future_action_window_size=ACTION_HORIZON - 1,
action_dim=ACTION_DIM,
state_dim=STATE_DIM,
num_video_frames=NUM_VIDEO_FRAMES,
num_action_tokens_per_timestep=2,
num_embodied_action_tokens_per_instruction=3,
num_inference_timesteps=2,
action_hidden_size=16,
action_num_layers=1,
action_num_heads=2,
action_attention_head_dim=8,
predictor_depth=1,
predictor_num_heads=2,
predictor_mlp_ratio=2.0,
)
config.validate_features()
return config
def make_train_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, 1, STATE_DIM),
ACTION: torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM),
"task": ["pick up the cube"] * batch_size,
}
def make_inference_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, STATE_DIM),
"task": ["pick up the cube"] * batch_size,
}
def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42) set_seed_all(42)
policy = VLAJEPAPolicy(make_config()) policy = VLAJEPAPolicy(make_config())
policy.train() policy.train()
@@ -224,9 +59,8 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
assert logs["wm_loss"] >= 0 assert logs["wm_loss"] >= 0
loss.backward() loss.backward()
assert any( assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad # Batch must not be mutated.
)
assert set(batch) == set(batch_before) assert set(batch) == set(batch_before)
for key, value in batch.items(): for key, value in batch.items():
if isinstance(value, Tensor): if isinstance(value, Tensor):
@@ -235,34 +69,75 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
assert value == batch_before[key] assert value == batch_before[key]
@torch.no_grad() @pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_vla_jepa_action_generation_shape( def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.train()
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
assert torch.isfinite(loss) and loss > 0
assert set(logs) == {"action_loss", "wm_loss", "loss"}
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_training_forward_various_dims(
patch_vla_jepa_external_models: None, patch_vla_jepa_external_models: None,
action_dim: int,
state_dim: int,
action_horizon: int,
) -> None: ) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
policy = VLAJEPAPolicy(config)
policy.train()
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
loss, _ = policy.forward(batch)
assert torch.isfinite(loss) and loss > 0
@torch.no_grad()
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42) set_seed_all(42)
policy = VLAJEPAPolicy(make_config()) policy = VLAJEPAPolicy(make_config())
policy.eval() policy.eval()
batch = make_inference_batch() batch = make_inference_batch()
action_chunk = policy.predict_action_chunk(batch) chunk = policy.predict_action_chunk(batch)
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
assert chunk.device.type == "cpu"
assert torch.isfinite(chunk).all()
assert tuple(action_chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE a1 = policy.select_action(batch)
assert action_chunk.device.type == "cpu" a2 = policy.select_action(batch)
assert torch.isfinite(action_chunk).all() assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
first_action = policy.select_action(batch) assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
second_action = policy.select_action(batch)
assert tuple(first_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(second_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert torch.isfinite(first_action).all()
assert torch.isfinite(second_action).all()
@torch.no_grad() @torch.no_grad()
def test_vla_jepa_inference_reproducibility( @pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
patch_vla_jepa_external_models: None, def test_action_generation_various_dims(
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
) -> None: ) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim)
policy = VLAJEPAPolicy(config)
policy.eval()
batch = make_inference_batch(state_dim=state_dim)
chunk = policy.predict_action_chunk(batch)
assert chunk.shape[-1] == action_dim
assert torch.isfinite(chunk).all()
@torch.no_grad()
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42) set_seed_all(42)
policy = VLAJEPAPolicy(make_config()) policy = VLAJEPAPolicy(make_config())
policy.eval() policy.eval()
@@ -270,7 +145,6 @@ def test_vla_jepa_inference_reproducibility(
set_seed_all(123) set_seed_all(123)
actions_1 = policy.predict_action_chunk(batch) actions_1 = policy.predict_action_chunk(batch)
set_seed_all(123) set_seed_all(123)
actions_2 = policy.predict_action_chunk(batch) actions_2 = policy.predict_action_chunk(batch)
@@ -278,7 +152,125 @@ def test_vla_jepa_inference_reproducibility(
assert torch.allclose(actions_1, actions_2, atol=1e-6) assert torch.allclose(actions_1, actions_2, atol=1e-6)
def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None: @torch.no_grad()
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
policy.eval()
for seed in [0, 42, 123]:
set_seed_all(seed)
chunk = policy.predict_action_chunk(make_inference_batch())
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
# ---------------------------------------------------------------------------
# Action queue behaviour
# ---------------------------------------------------------------------------
@torch.no_grad()
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
# First call fills the queue (n_action_steps items) and pops one.
a1 = policy.select_action(batch)
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
# Second call pops from the existing queue without calling predict_action_chunk.
a2 = policy.select_action(batch)
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
@torch.no_grad()
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
policy.select_action(make_inference_batch())
assert len(policy._queues[ACTION]) > 0
policy.reset()
assert len(policy._queues[ACTION]) == 0
# ---------------------------------------------------------------------------
# Format conversion
# ---------------------------------------------------------------------------
def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None:
import numpy as np
from PIL import Image
policy = VLAJEPAPolicy(make_config())
examples = policy._lerobot_to_native(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
assert set(ex) >= {"image", "video", "lang", "action", "state"}
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
assert ex["state"].shape == (1, STATE_DIM)
def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
for ex in policy._lerobot_to_native(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
examples = policy._lerobot_to_native(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert all(ex["lang"] == "open the drawer" for ex in policy._lerobot_to_native(batch))
def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
from lerobot.utils.constants import OBS_STATE
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert all("state" not in ex for ex in policy._lerobot_to_native(batch))
def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
loss, logs = policy._native_to_lerobot({"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)})
assert torch.isfinite(loss)
assert set(logs) == {"action_loss", "wm_loss", "loss"}
assert logs["action_loss"] == pytest.approx(0.5, abs=1e-5)
assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5)
def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
_, logs = policy._native_to_lerobot({"wm_loss": torch.tensor(0.3)})
assert "action_loss" not in logs
assert logs["wm_loss"] == pytest.approx(0.3, abs=1e-5)
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# ---------------------------------------------------------------------------
def test_pretrained_checkpoint_loads_from_hf_cache() -> None:
import torch
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.errors import LocalEntryNotFoundError from huggingface_hub.errors import LocalEntryNotFoundError
@@ -291,12 +283,10 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
try: try:
checkpoint_path = hf_hub_download( checkpoint_path = hf_hub_download(
repo_id=repo_id, repo_id=repo_id, filename=checkpoint_filename, local_files_only=True
filename=checkpoint_filename,
local_files_only=True,
) )
except LocalEntryNotFoundError: except LocalEntryNotFoundError:
pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.") pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.")
try: try:
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False) checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
@@ -309,7 +299,6 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
or checkpoint.get("model") or checkpoint.get("model")
or checkpoint or checkpoint
) )
assert isinstance(state_dict, dict) assert isinstance(state_dict, dict)
assert len(state_dict) > 0 assert len(state_dict) > 0
assert all(isinstance(key, str) for key in list(state_dict)[:10]) assert all(isinstance(k, str) for k in list(state_dict)[:10])
@@ -0,0 +1,53 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
from lerobot.policies.vla_jepa.world_model import (
ActionConditionedVideoPredictor,
)
_ACTION_EMBED_DIM = 8
def _make_predictor(
embed_dim: int = 8,
action_embed_dim: int = _ACTION_EMBED_DIM,
predictor_embed_dim: int = 16,
num_action_tokens: int = 2,
) -> ActionConditionedVideoPredictor:
return ActionConditionedVideoPredictor(
embed_dim=embed_dim,
action_embed_dim=action_embed_dim,
predictor_embed_dim=predictor_embed_dim,
depth=1,
num_heads=2,
mlp_ratio=2.0,
num_action_tokens_per_step=num_action_tokens,
)
@pytest.mark.parametrize(
"batch,num_steps,tokens_per_frame,embed_dim",
[
(1, 2, 1, 8),
(2, 3, 4, 8),
(4, 5, 2, 16),
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM)
frame_tokens = torch.randn(batch, num_steps, tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps, 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)
assert tuple(out.shape) == (batch, num_steps, tokens_per_frame, embed_dim)
assert torch.isfinite(out).all()
def test_predictor_step_mismatch_raises() -> None:
predictor = _make_predictor()
frame_tokens = torch.randn(2, 3, 4, 8)
with pytest.raises(ValueError, match="Expected 3 action steps"):
predictor(frame_tokens, torch.randn(2, 2, 2, 8))