adding more tests to ensure good coverage

This commit is contained in:
Maximellerbach
2026-05-13 15:55:04 +02:00
parent e28d34a3cf
commit ddaff399b5
5 changed files with 669 additions and 213 deletions
+202 -213
View File
@@ -4,18 +4,10 @@ from __future__ import annotations
import os
from copy import deepcopy
from types import SimpleNamespace
import numpy as np
import pytest
import torch
from PIL import Image
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from torch import Tensor
pytest.importorskip("transformers")
pytest.importorskip("diffusers")
@@ -24,190 +16,33 @@ pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
)
from conftest import ( # noqa: E402
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
EXPECTED_ACTION_CHUNK_SHAPE,
EXPECTED_SELECT_ACTION_SHAPE,
N_ACTION_STEPS,
STATE_DIM,
make_config,
make_inference_batch,
make_train_batch,
set_seed_all,
)
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
from lerobot.utils.constants import ACTION # noqa: E402
# Dimensions of the synthetic data used by the batch/config helpers in this module.
BATCH_SIZE = 2  # default number of samples produced by make_train_batch / make_inference_batch
ACTION_DIM = 3  # size of the last axis of action tensors; passed as ``action_dim`` in make_config
STATE_DIM = 4  # size of the last axis of state tensors; passed as ``state_dim`` in make_config
IMAGE_SIZE = 8  # height and width of the square random camera images
ACTION_HORIZON = 4  # actions per training sample; passed as ``chunk_size`` in make_config
N_ACTION_STEPS = 2  # passed as ``n_action_steps`` in make_config
NUM_VIDEO_FRAMES = 3  # frames per training sample; passed as ``num_video_frames`` in make_config
# Shapes the tests compare against the outputs of predict_action_chunk / select_action.
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
# NOTE(review): presumably the Hugging Face repo and subfolder read by the pretrained-checkpoint
# test; their use is not visible in this excerpt — confirm.
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO"
def set_seed_all(seed: int) -> None:
    """Seed every random number generator the tests may draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available the CUDA generators are seeded explicitly as well.

    Args:
        seed: Value applied to all generators.
    """
    # Local import: ``random`` is not among the module-level imports of this file.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
# ---------------------------------------------------------------------------
# Core training / inference tests
# ---------------------------------------------------------------------------
class _FakeQwenBackbone(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
@property
def device(self) -> torch.device:
return self.weight.device
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden_size = self.config.hidden_size
values = torch.arange(
batch_size * seq_len * hidden_size,
device=input_ids.device,
dtype=torch.float32,
).view(batch_size, seq_len, hidden_size)
hidden = values / values.numel() + self.weight
return SimpleNamespace(hidden_states=[hidden])
class _FakeQwenInterface(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = _FakeQwenBackbone(hidden_size=16)
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
return torch.float32 if dtype_name == "float32" else torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000
def build_inputs(
self,
images: list[list[Image.Image]],
instructions: list[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, Tensor]:
batch_size = len(images)
del images, instructions, action_prompt, embodied_prompt
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
token_ids = (
[10]
+ list(range(1000, 1000 + action_count))
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
+ [11]
)
input_ids = torch.tensor(
[token_ids] * batch_size,
device=self.model.device,
dtype=torch.long,
)
return {"input_ids": input_ids}
@staticmethod
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
return Image.fromarray(image)
class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
@property
def device(self) -> torch.device:
return self.weight.device
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
batch_size, num_frames = pixel_values_videos.shape[:2]
hidden_size = self.config.hidden_size
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
class _FakeVideoProcessor:
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap the external model entry points of ``modeling_vla_jepa`` for local fakes.

    Replaces ``Qwen3VLInterface`` with ``_FakeQwenInterface`` and makes
    ``AutoModel.from_pretrained`` / ``AutoVideoProcessor.from_pretrained`` return
    ``_FakeVideoEncoder`` / ``_FakeVideoProcessor`` instances. ``monkeypatch``
    restores the originals when the test finishes.
    """
    from lerobot.policies.vla_jepa import modeling_vla_jepa

    def _load_fake_encoder(*args, **kwargs):
        # Arguments are accepted and ignored; a fresh fake is built on every call.
        return _FakeVideoEncoder()

    def _load_fake_processor(*args, **kwargs):
        # Arguments are accepted and ignored; a fresh fake is built on every call.
        return _FakeVideoProcessor()

    monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
    monkeypatch.setattr(modeling_vla_jepa.AutoModel, "from_pretrained", _load_fake_encoder)
    monkeypatch.setattr(modeling_vla_jepa.AutoVideoProcessor, "from_pretrained", _load_fake_processor)
def make_config(
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
) -> VLAJEPAConfig:
    """Build a small CPU ``VLAJEPAConfig`` for the tests.

    Called without arguments it produces the same configuration as before; the
    parameters let parametrized tests vary the action/state sizes and the horizon.

    Args:
        action_dim: Size of one action vector.
        state_dim: Size of one state vector. A non-positive value leaves the
            state feature out of ``input_features``.
        action_horizon: Value used for ``chunk_size``; ``future_action_window_size``
            is derived as ``action_horizon - 1``.

    Returns:
        The configuration, after ``validate_features()`` has been called on it.
    """
    input_features = {
        f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
    }
    if state_dim > 0:
        input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
    # NOTE(review): assumes VLAJEPAConfig accepts state_dim=0 together with a missing
    # state feature — confirm against the configuration class.

    config = VLAJEPAConfig(
        input_features=input_features,
        output_features={
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
        },
        device="cpu",
        chunk_size=action_horizon,
        # Never request more executed steps than the chunk contains.
        n_action_steps=min(N_ACTION_STEPS, action_horizon),
        future_action_window_size=action_horizon - 1,
        action_dim=action_dim,
        state_dim=state_dim,
        num_video_frames=NUM_VIDEO_FRAMES,
        num_action_tokens_per_timestep=2,
        num_embodied_action_tokens_per_instruction=3,
        num_inference_timesteps=2,
        action_hidden_size=16,
        action_num_layers=1,
        action_num_heads=2,
        action_attention_head_dim=8,
        predictor_depth=1,
        predictor_num_heads=2,
        predictor_mlp_ratio=2.0,
    )
    config.validate_features()
    return config
def make_train_batch(
    batch_size: int = BATCH_SIZE,
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
) -> dict[str, Tensor | list[str]]:
    """Create a random training batch with images, state, actions and task strings.

    With default arguments the keys, shapes and random-draw order match the
    previous parameterless-dimension version.

    Args:
        batch_size: Number of samples.
        action_dim: Size of the last axis of the action tensor.
        state_dim: Size of the last axis of the state tensor. A non-positive
            value leaves the state entry out of the batch.
        action_horizon: Number of actions per sample.
    """
    batch: dict[str, Tensor | list[str]] = {
        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE),
    }
    if state_dim > 0:
        batch[OBS_STATE] = torch.randn(batch_size, 1, state_dim)
    batch[ACTION] = torch.randn(batch_size, action_horizon, action_dim)
    batch["task"] = ["pick up the cube"] * batch_size
    return batch
def make_inference_batch(
    batch_size: int = BATCH_SIZE,
    state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
    """Create a random inference batch with one image, state and task per sample.

    With default arguments the keys, shapes and random-draw order match the
    previous version.

    Args:
        batch_size: Number of samples.
        state_dim: Size of the last axis of the state tensor. A non-positive
            value leaves the state entry out of the batch.
    """
    batch: dict[str, Tensor | list[str]] = {
        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
    }
    if state_dim > 0:
        batch[OBS_STATE] = torch.randn(batch_size, state_dim)
    batch["task"] = ["pick up the cube"] * batch_size
    return batch
def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.train()
@@ -224,9 +59,8 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
assert logs["wm_loss"] >= 0
loss.backward()
assert any(
param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad
)
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
# Batch must not be mutated.
assert set(batch) == set(batch_before)
for key, value in batch.items():
if isinstance(value, Tensor):
@@ -235,34 +69,75 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
assert value == batch_before[key]
@torch.no_grad()
def test_vla_jepa_action_generation_shape(
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
    """The training forward pass gives a finite, positive loss and the expected log keys."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.train()

    train_batch = make_train_batch(batch_size=batch_size)
    loss, logs = policy.forward(train_batch)

    assert torch.isfinite(loss)
    assert loss > 0
    assert set(logs) == {"action_loss", "wm_loss", "loss"}
@pytest.mark.parametrize(
    "action_dim,state_dim,action_horizon",
    [(3, 4, 4), (7, 0, 16), (6, 8, 8)],
)
def test_training_forward_various_dims(
    patch_vla_jepa_external_models: None,
    action_dim: int,
    state_dim: int,
    action_horizon: int,
) -> None:
    """The training forward pass gives a finite, positive loss for each dimension combination."""
    set_seed_all(42)
    # The same three sizes drive both the config and the batch.
    dims = {"action_dim": action_dim, "state_dim": state_dim, "action_horizon": action_horizon}

    policy = VLAJEPAPolicy(make_config(**dims))
    policy.train()

    loss, _ = policy.forward(make_train_batch(**dims))

    assert torch.isfinite(loss)
    assert loss > 0
@torch.no_grad()
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
    """Check the shape, device and finiteness of predicted chunks and selected actions."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()

    # Full action chunk.
    chunk = policy.predict_action_chunk(batch)
    assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
    assert chunk.device.type == "cpu"
    assert torch.isfinite(chunk).all()

    # Two consecutive single-step actions.
    a1 = policy.select_action(batch)
    a2 = policy.select_action(batch)
    assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
@torch.no_grad()
def test_vla_jepa_inference_reproducibility(
patch_vla_jepa_external_models: None,
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
def test_action_generation_various_dims(
    patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
) -> None:
    """Predicted chunks end in ``action_dim`` and stay finite for each size combination."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config(action_dim=action_dim, state_dim=state_dim))
    policy.eval()

    inference_batch = make_inference_batch(state_dim=state_dim)
    predicted = policy.predict_action_chunk(inference_batch)

    assert predicted.shape[-1] == action_dim
    assert torch.isfinite(predicted).all()
@torch.no_grad()
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
@@ -270,7 +145,6 @@ def test_vla_jepa_inference_reproducibility(
set_seed_all(123)
actions_1 = policy.predict_action_chunk(batch)
set_seed_all(123)
actions_2 = policy.predict_action_chunk(batch)
@@ -278,7 +152,125 @@ def test_vla_jepa_inference_reproducibility(
assert torch.allclose(actions_1, actions_2, atol=1e-6)
def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
@torch.no_grad()
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
    """Predicted action chunks contain only finite values for several seeds."""
    policy = VLAJEPAPolicy(make_config())
    policy.eval()

    for seed in (0, 42, 123):
        # Seed first so both the random batch and the prediction depend on it.
        set_seed_all(seed)
        batch = make_inference_batch()
        chunk = policy.predict_action_chunk(batch)
        all_finite = torch.isfinite(chunk).all()
        assert all_finite, f"non-finite actions with seed={seed}"
# ---------------------------------------------------------------------------
# Action queue behaviour
# ---------------------------------------------------------------------------
@torch.no_grad()
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
    """Consecutive ``select_action`` calls consume the queued actions one at a time."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()

    # First call fills the queue (n_action_steps items) and pops one.
    a1 = policy.select_action(batch)
    assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1

    # Second call pops from the existing queue without calling predict_action_chunk.
    a2 = policy.select_action(batch)
    # One more item must be gone. ``max`` covers N_ACTION_STEPS == 1, where the
    # queue is refilled with a single item and emptied again.
    assert len(policy._queues[ACTION]) == max(N_ACTION_STEPS - 2, 0)

    assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
@torch.no_grad()
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
    """``reset()`` empties an action queue that still holds pending actions."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()

    batch = make_inference_batch()
    policy.select_action(batch)
    # Look the queue up again after each step rather than caching the object.
    pending_before = len(policy._queues[ACTION])
    assert pending_before > 0

    policy.reset()
    pending_after = len(policy._queues[ACTION])
    assert pending_after == 0
# ---------------------------------------------------------------------------
# Format conversion
# ---------------------------------------------------------------------------
def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None:
    """A training batch converts to one native example per sample with the expected fields."""
    import numpy as np
    from PIL import Image

    policy = VLAJEPAPolicy(make_config())
    examples = policy._lerobot_to_native(make_train_batch())
    assert len(examples) == BATCH_SIZE

    required_keys = {"image", "video", "lang", "action", "state"}
    for example in examples:
        assert set(example) >= required_keys

        images = example["image"]
        assert len(images) == 1 and isinstance(images[0], Image.Image)

        video = example["video"]
        assert video.ndim == 5 and video.dtype == np.uint8  # [V,T,H,W,C]

        assert example["action"].shape == (ACTION_HORIZON, ACTION_DIM)
        assert example["state"].shape == (1, STATE_DIM)
def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
    """Inference batches convert to examples that carry no ``action`` entry."""
    policy = VLAJEPAPolicy(make_config())
    examples = policy._lerobot_to_native(make_inference_batch())
    # Guard against a vacuous pass: the loop below checks nothing on an empty result.
    assert len(examples) == BATCH_SIZE
    for ex in examples:
        assert "action" not in ex
        assert "image" in ex and "video" in ex and "lang" in ex
def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
    """Without a ``task`` entry every example still gets a non-empty string instruction."""
    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    del batch["task"]
    examples = policy._lerobot_to_native(batch)
    # Guard against a vacuous pass: ``all()`` is true for an empty result.
    assert len(examples) == BATCH_SIZE
    assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
    """A single string ``task`` is applied to every example in the batch."""
    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    batch["task"] = "open the drawer"
    examples = policy._lerobot_to_native(batch)
    # Guard against a vacuous pass: ``all()`` is true for an empty result.
    assert len(examples) == BATCH_SIZE
    assert all(ex["lang"] == "open the drawer" for ex in examples)
def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
    """Without a state entry in the batch no example carries a ``state`` field."""
    from lerobot.utils.constants import OBS_STATE

    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    del batch[OBS_STATE]
    examples = policy._lerobot_to_native(batch)
    # Guard against a vacuous pass: ``all()`` is true for an empty result.
    assert len(examples) == BATCH_SIZE
    assert all("state" not in ex for ex in examples)
def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None:
    """With both native losses present, the logs report each one plus a ``loss`` entry."""
    policy = VLAJEPAPolicy(make_config())
    native_losses = {"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)}

    loss, logs = policy._native_to_lerobot(native_losses)

    assert torch.isfinite(loss)
    assert set(logs) == {"action_loss", "wm_loss", "loss"}
    for name, expected in (("action_loss", 0.5), ("wm_loss", 0.1)):
        assert logs[name] == pytest.approx(expected, abs=1e-5)
def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None:
    """With only ``wm_loss`` supplied, the logs contain it and no ``action_loss`` entry."""
    policy = VLAJEPAPolicy(make_config())
    native_losses = {"wm_loss": torch.tensor(0.3)}

    _, logs = policy._native_to_lerobot(native_losses)

    assert "action_loss" not in logs
    assert logs["wm_loss"] == pytest.approx(0.3, abs=1e-5)
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# ---------------------------------------------------------------------------
def test_pretrained_checkpoint_loads_from_hf_cache() -> None:
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import LocalEntryNotFoundError
@@ -291,12 +283,10 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
try:
checkpoint_path = hf_hub_download(
repo_id=repo_id,
filename=checkpoint_filename,
local_files_only=True,
repo_id=repo_id, filename=checkpoint_filename, local_files_only=True
)
except LocalEntryNotFoundError:
pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.")
pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.")
try:
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
@@ -309,7 +299,6 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
or checkpoint.get("model")
or checkpoint
)
assert isinstance(state_dict, dict)
assert len(state_dict) > 0
assert all(isinstance(key, str) for key in list(state_dict)[:10])
assert all(isinstance(k, str) for k in list(state_dict)[:10])