From ddaff399b58cfc24f9286ceafe9d3540b3612d24 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Wed, 13 May 2026 15:55:04 +0200 Subject: [PATCH] adding more tests to ensure good coverage --- tests/policies/vla_jepa/conftest.py | 232 ++++++++++ tests/policies/vla_jepa/test_action_head.py | 119 +++++ tests/policies/vla_jepa/test_configuration.py | 63 +++ tests/policies/vla_jepa/test_vla_jepa.py | 415 +++++++++--------- tests/policies/vla_jepa/test_world_model.py | 53 +++ 5 files changed, 669 insertions(+), 213 deletions(-) create mode 100644 tests/policies/vla_jepa/conftest.py create mode 100644 tests/policies/vla_jepa/test_action_head.py create mode 100644 tests/policies/vla_jepa/test_configuration.py create mode 100644 tests/policies/vla_jepa/test_world_model.py diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py new file mode 100644 index 000000000..2e7d047cd --- /dev/null +++ b/tests/policies/vla_jepa/conftest.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python +"""Shared fixtures and helpers for VLA-JEPA tests.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +from PIL import Image +from torch import Tensor, nn + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + +# --------------------------------------------------------------------------- +# Shared constants +# --------------------------------------------------------------------------- + +BATCH_SIZE = 2 +ACTION_DIM = 3 +STATE_DIM = 4 +IMAGE_SIZE = 8 +ACTION_HORIZON = 4 +N_ACTION_STEPS = 2 +NUM_VIDEO_FRAMES = 3 +QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone + +EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) +EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def set_seed_all(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def make_config( + action_dim: int = ACTION_DIM, + state_dim: int = STATE_DIM, + action_horizon: int = ACTION_HORIZON, + num_video_frames: int = NUM_VIDEO_FRAMES, +) -> VLAJEPAConfig: + config = VLAJEPAConfig( + input_features={ + f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + }, + output_features={ + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)), + }, + device="cpu", + chunk_size=action_horizon, + n_action_steps=min(N_ACTION_STEPS, action_horizon), + future_action_window_size=action_horizon - 1, + action_dim=action_dim, + state_dim=state_dim, + num_video_frames=num_video_frames, + num_action_tokens_per_timestep=2, + num_embodied_action_tokens_per_instruction=3, + num_inference_timesteps=2, + action_hidden_size=QWEN_HIDDEN_SIZE, + action_num_layers=1, + action_num_heads=2, + action_attention_head_dim=8, + predictor_depth=1, + predictor_num_heads=2, + predictor_mlp_ratio=2.0, + jepa_tubelet_size=1, + ) + config.validate_features() + return config + + +def make_train_batch( + batch_size: int = BATCH_SIZE, + action_dim: int = ACTION_DIM, + state_dim: int = STATE_DIM, + action_horizon: int = ACTION_HORIZON, + num_video_frames: int = NUM_VIDEO_FRAMES, +) -> dict[str, Tensor | list[str]]: + return { + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE), + OBS_STATE: torch.randn(batch_size, 1, state_dim), + ACTION: torch.randn(batch_size, action_horizon, action_dim), + "task": ["pick up the cube"] * batch_size, + } + + +def make_inference_batch( + batch_size: int = BATCH_SIZE, + state_dim: int = STATE_DIM, +) -> dict[str, Tensor | list[str]]: + return { + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE), + OBS_STATE: torch.randn(batch_size, state_dim), + "task": ["pick up the cube"] * batch_size, + } + + +# --------------------------------------------------------------------------- +# Fake external models (replace Qwen3-VL and V-JEPA at test time) +# --------------------------------------------------------------------------- + + +class _FakeQwenBackbone(nn.Module): + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(1)) + self.config = SimpleNamespace( + hidden_size=hidden_size, + text_config=SimpleNamespace(hidden_size=hidden_size), + ) + + @property + def device(self) -> torch.device: + return self.weight.device + + def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace: + batch_size, seq_len = input_ids.shape + hidden_size = self.config.hidden_size + values = torch.arange( + batch_size * seq_len * hidden_size, + device=input_ids.device, + dtype=torch.float32, + ).view(batch_size, seq_len, hidden_size) + hidden = values / values.numel() + self.weight + return SimpleNamespace(hidden_states=[hidden]) + + +class _FakeQwenInterface(nn.Module): + def __init__(self, config: VLAJEPAConfig) -> None: + super().__init__() + self.config = config + self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE) + + @staticmethod + def _get_torch_dtype(dtype_name: str) -> torch.dtype: + return torch.float32 if dtype_name == "float32" else torch.bfloat16 + + def expand_tokenizer(self) -> tuple[list[str], list[int], int]: + max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep + action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)] + action_token_ids = list(range(1000, 1000 + max_action_tokens)) + return action_tokens, action_token_ids, 2000 + + def build_inputs( + self, + images: list[list[Image.Image]], + instructions: list[str], + action_prompt: str, + embodied_prompt: str, + ) -> dict[str, Tensor]: + batch_size = len(images) + del images, instructions, action_prompt, embodied_prompt + action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep + token_ids = ( + [10] + + list(range(1000, 1000 + action_count)) + + [2000] * self.config.num_embodied_action_tokens_per_instruction + + [11] + ) + return { + "input_ids": torch.tensor( + [token_ids] * batch_size, + device=self.model.device, + dtype=torch.long, + ) + } + + @staticmethod + def tensor_to_pil(image_tensor: Tensor) -> Image.Image: + image = image_tensor.detach().cpu() + if image.ndim == 3 and image.shape[0] in (1, 3): + image = image.permute(1, 2, 0) + image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy() + return Image.fromarray(image) + + +class _FakeVideoEncoder(nn.Module): + def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(1)) + self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size) + + @property + def device(self) -> torch.device: + return self.weight.device + + def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor: + batch_size, num_frames = pixel_values_videos.shape[:2] + hidden_size = self.config.hidden_size + frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False) + return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size) + + +class _FakeVideoProcessor: + def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]: + assert return_tensors == "pt" + return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)} + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None: + from lerobot.policies.vla_jepa import modeling_vla_jepa + + monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface) + monkeypatch.setattr( + modeling_vla_jepa.AutoModel, + "from_pretrained", + lambda *args, **kwargs: _FakeVideoEncoder(), + ) + monkeypatch.setattr( + modeling_vla_jepa.AutoVideoProcessor, + "from_pretrained", + lambda *args, **kwargs: _FakeVideoProcessor(), + ) diff --git a/tests/policies/vla_jepa/test_action_head.py b/tests/policies/vla_jepa/test_action_head.py new file mode 100644 index 000000000..eb2d3168d --- /dev/null +++ b/tests/policies/vla_jepa/test_action_head.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import pytest +import torch + +pytest.importorskip("diffusers") + +from conftest import ( + ACTION_DIM, + ACTION_HORIZON, + BATCH_SIZE, + QWEN_HIDDEN_SIZE, + STATE_DIM, + make_config, + set_seed_all, +) # noqa: E402 + +from lerobot.policies.vla_jepa.action_head import ( # noqa: E402 + VLAJEPAActionHead, +) + +# --------------------------------------------------------------------------- +# VLAJEPAActionHead +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), # default test dims + (7, 0, 16), # no proprioceptive state, production-like action space + (6, 8, 8), # medium dims + ], +) +def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None: + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32) + assert t.shape == (200,) + assert torch.isfinite(t).all() + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None: + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(2, action_horizon, action_dim) + timesteps = torch.randint(0, 100, (2,)) + + state = torch.randn(2, state_dim) if state_dim > 0 else None + out_with = head._build_inputs(conditioning, actions, state, timesteps) + out_none = head._build_inputs(conditioning, actions, None, timesteps) + + assert out_with.ndim == 3 and out_none.ndim == 3 + if state_dim > 0: + assert out_with.shape[1] > out_none.shape[1] + assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all() + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(2, action_horizon, action_dim) + state = torch.randn(2, state_dim) if state_dim > 0 else None + loss = head.forward(conditioning, actions, state) + assert loss.shape == () + assert torch.isfinite(loss) and loss > 0 + + +def test_action_head_forward_gradient_flows() -> None: + set_seed_all(42) + config = make_config() + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + state = torch.randn(BATCH_SIZE, STATE_DIM) + loss = head.forward(conditioning, actions, state) + loss.backward() + assert any(p.grad is not None for p in head.parameters() if p.requires_grad) + + +@torch.no_grad() +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE) + state = torch.randn(2, state_dim) if state_dim > 0 else None + pred = head.predict_action(conditioning, state) + assert tuple(pred.shape) == (2, action_horizon, action_dim) + assert torch.isfinite(pred).all() diff --git a/tests/policies/vla_jepa/test_configuration.py b/tests/policies/vla_jepa/test_configuration.py new file mode 100644 index 000000000..34e9bcff8 --- /dev/null +++ b/tests/policies/vla_jepa/test_configuration.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import pytest +from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + + +def test_delta_indices() -> None: + config = make_config() + assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES)) + assert config.action_delta_indices == list(range(ACTION_HORIZON)) + + +def test_n_action_steps_exceeds_chunk_size_raises() -> None: + with pytest.raises(ValueError, match="n_action_steps"): + VLAJEPAConfig(chunk_size=4, n_action_steps=8, future_action_window_size=3) + + +def test_future_window_exceeds_chunk_size_raises() -> None: + with pytest.raises(ValueError, match="predicted action horizon"): + VLAJEPAConfig(chunk_size=4, n_action_steps=4, future_action_window_size=4) + + +def test_too_few_video_frames_raises() -> None: + with pytest.raises(ValueError, match="video_horizon"): + VLAJEPAConfig( + chunk_size=16, + n_action_steps=16, + future_action_window_size=15, + num_video_frames=2, + jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0 + ) + + +def test_validate_features_no_image_raises() -> None: + config = VLAJEPAConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}, + ) + with pytest.raises(ValueError, match="at least one visual input feature"): + config.validate_features() + + +def test_validate_features_no_action_raises() -> None: + config = VLAJEPAConfig( + input_features={ + f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + }, + output_features={}, + ) + with pytest.raises(ValueError, match="action output feature"): + config.validate_features() + + +def test_validate_features_sets_action_dim_from_feature() -> None: + config = make_config(action_dim=6, state_dim=10) + assert config.action_dim == 6 + assert config.state_dim == 10 diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index ffec4c201..ae51126de 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -4,18 +4,10 @@ from __future__ import annotations import os from copy import deepcopy -from types import SimpleNamespace -import numpy as np import pytest import torch -from PIL import Image -from torch import Tensor, nn - -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig -from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from torch import Tensor pytest.importorskip("transformers") pytest.importorskip("diffusers") @@ -24,190 +16,33 @@ pytestmark = pytest.mark.filterwarnings( "ignore:In CPU autocast, but the target dtype is not supported:UserWarning" ) +from conftest import ( # noqa: E402 + ACTION_DIM, + ACTION_HORIZON, + BATCH_SIZE, + EXPECTED_ACTION_CHUNK_SHAPE, + EXPECTED_SELECT_ACTION_SHAPE, + N_ACTION_STEPS, + STATE_DIM, + make_config, + make_inference_batch, + make_train_batch, + set_seed_all, +) + +from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402 +from lerobot.utils.constants import ACTION # noqa: E402 -BATCH_SIZE = 2 -ACTION_DIM = 3 -STATE_DIM = 4 -IMAGE_SIZE = 8 -ACTION_HORIZON = 4 -N_ACTION_STEPS = 2 -NUM_VIDEO_FRAMES = 3 -EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) -EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM) PRETRAINED_REPO_ID = "ginwind/VLA-JEPA" PRETRAINED_SUBFOLDER = "LIBERO" -def set_seed_all(seed: int) -> None: - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) +# --------------------------------------------------------------------------- +# Core training / inference tests +# --------------------------------------------------------------------------- -class _FakeQwenBackbone(nn.Module): - def __init__(self, hidden_size: int) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(1)) - self.config = SimpleNamespace( - hidden_size=hidden_size, - text_config=SimpleNamespace(hidden_size=hidden_size), - ) - - @property - def device(self) -> torch.device: - return self.weight.device - - def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace: - batch_size, seq_len = input_ids.shape - hidden_size = self.config.hidden_size - values = torch.arange( - batch_size * seq_len * hidden_size, - device=input_ids.device, - dtype=torch.float32, - ).view(batch_size, seq_len, hidden_size) - hidden = values / values.numel() + self.weight - return SimpleNamespace(hidden_states=[hidden]) - - -class _FakeQwenInterface(nn.Module): - def __init__(self, config: VLAJEPAConfig) -> None: - super().__init__() - self.config = config - self.model = _FakeQwenBackbone(hidden_size=16) - - @staticmethod - def _get_torch_dtype(dtype_name: str) -> torch.dtype: - return torch.float32 if dtype_name == "float32" else torch.bfloat16 - - def expand_tokenizer(self) -> tuple[list[str], list[int], int]: - max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep - action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)] - action_token_ids = list(range(1000, 1000 + max_action_tokens)) - return action_tokens, action_token_ids, 2000 - - def build_inputs( - self, - images: list[list[Image.Image]], - instructions: list[str], - action_prompt: str, - embodied_prompt: str, - ) -> dict[str, Tensor]: - batch_size = len(images) - del images, instructions, action_prompt, embodied_prompt - action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep - token_ids = ( - [10] - + list(range(1000, 1000 + action_count)) - + [2000] * self.config.num_embodied_action_tokens_per_instruction - + [11] - ) - input_ids = torch.tensor( - [token_ids] * batch_size, - device=self.model.device, - dtype=torch.long, - ) - return {"input_ids": input_ids} - - @staticmethod - def tensor_to_pil(image_tensor: Tensor) -> Image.Image: - image = image_tensor.detach().cpu() - if image.ndim == 3 and image.shape[0] in (1, 3): - image = image.permute(1, 2, 0) - image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy() - return Image.fromarray(image) - - -class _FakeVideoEncoder(nn.Module): - def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None: - super().__init__() - self.weight = nn.Parameter(torch.ones(1)) - self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size) - - @property - def device(self) -> torch.device: - return self.weight.device - - def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor: - batch_size, num_frames = pixel_values_videos.shape[:2] - hidden_size = self.config.hidden_size - frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False) - return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size) - - -class _FakeVideoProcessor: - def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]: - assert return_tensors == "pt" - return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)} - - -@pytest.fixture -def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None: - from lerobot.policies.vla_jepa import modeling_vla_jepa - - monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface) - monkeypatch.setattr( - modeling_vla_jepa.AutoModel, - "from_pretrained", - lambda *args, **kwargs: _FakeVideoEncoder(), - ) - monkeypatch.setattr( - modeling_vla_jepa.AutoVideoProcessor, - "from_pretrained", - lambda *args, **kwargs: _FakeVideoProcessor(), - ) - - -def make_config() -> VLAJEPAConfig: - config = VLAJEPAConfig( - input_features={ - f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), - OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)), - }, - output_features={ - ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)), - }, - device="cpu", - chunk_size=ACTION_HORIZON, - n_action_steps=N_ACTION_STEPS, - future_action_window_size=ACTION_HORIZON - 1, - action_dim=ACTION_DIM, - state_dim=STATE_DIM, - num_video_frames=NUM_VIDEO_FRAMES, - num_action_tokens_per_timestep=2, - num_embodied_action_tokens_per_instruction=3, - num_inference_timesteps=2, - action_hidden_size=16, - action_num_layers=1, - action_num_heads=2, - action_attention_head_dim=8, - predictor_depth=1, - predictor_num_heads=2, - predictor_mlp_ratio=2.0, - ) - config.validate_features() - return config - - -def make_train_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]: - return { - f"{OBS_IMAGES}.laptop": torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE), - OBS_STATE: torch.randn(batch_size, 1, STATE_DIM), - ACTION: torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM), - "task": ["pick up the cube"] * batch_size, - } - - -def make_inference_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]: - return { - f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE), - OBS_STATE: torch.randn(batch_size, STATE_DIM), - "task": ["pick up the cube"] * batch_size, - } - - -def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None: +def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None: set_seed_all(42) policy = VLAJEPAPolicy(make_config()) policy.train() @@ -224,9 +59,8 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> assert logs["wm_loss"] >= 0 loss.backward() - assert any( - param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad - ) + assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad) + # Batch must not be mutated. assert set(batch) == set(batch_before) for key, value in batch.items(): if isinstance(value, Tensor): @@ -235,34 +69,75 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> assert value == batch_before[key] -@torch.no_grad() -def test_vla_jepa_action_generation_shape( +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.train() + loss, logs = policy.forward(make_train_batch(batch_size=batch_size)) + assert torch.isfinite(loss) and loss > 0 + assert set(logs) == {"action_loss", "wm_loss", "loss"} + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_training_forward_various_dims( patch_vla_jepa_external_models: None, + action_dim: int, + state_dim: int, + action_horizon: int, ) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + policy = VLAJEPAPolicy(config) + policy.train() + batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + loss, _ = policy.forward(batch) + assert torch.isfinite(loss) and loss > 0 + + +@torch.no_grad() +def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None: set_seed_all(42) policy = VLAJEPAPolicy(make_config()) policy.eval() batch = make_inference_batch() - action_chunk = policy.predict_action_chunk(batch) + chunk = policy.predict_action_chunk(batch) + assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE + assert chunk.device.type == "cpu" + assert torch.isfinite(chunk).all() - assert tuple(action_chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE - assert action_chunk.device.type == "cpu" - assert torch.isfinite(action_chunk).all() - - first_action = policy.select_action(batch) - second_action = policy.select_action(batch) - - assert tuple(first_action.shape) == EXPECTED_SELECT_ACTION_SHAPE - assert tuple(second_action.shape) == EXPECTED_SELECT_ACTION_SHAPE - assert torch.isfinite(first_action).all() - assert torch.isfinite(second_action).all() + a1 = policy.select_action(batch) + a2 = policy.select_action(batch) + assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE + assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE + assert torch.isfinite(a1).all() and torch.isfinite(a2).all() @torch.no_grad() -def test_vla_jepa_inference_reproducibility( - patch_vla_jepa_external_models: None, +@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)]) +def test_action_generation_various_dims( + patch_vla_jepa_external_models: None, action_dim: int, state_dim: int ) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim) + policy = VLAJEPAPolicy(config) + policy.eval() + batch = make_inference_batch(state_dim=state_dim) + chunk = policy.predict_action_chunk(batch) + assert chunk.shape[-1] == action_dim + assert torch.isfinite(chunk).all() + + +@torch.no_grad() +def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None: set_seed_all(42) policy = VLAJEPAPolicy(make_config()) policy.eval() @@ -270,7 +145,6 @@ def test_vla_jepa_inference_reproducibility( set_seed_all(123) actions_1 = policy.predict_action_chunk(batch) - set_seed_all(123) actions_2 = policy.predict_action_chunk(batch) @@ -278,7 +152,125 @@ def test_vla_jepa_inference_reproducibility( assert torch.allclose(actions_1, actions_2, atol=1e-6) -def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None: +@torch.no_grad() +def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + policy.eval() + for seed in [0, 42, 123]: + set_seed_all(seed) + chunk = policy.predict_action_chunk(make_inference_batch()) + assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}" + + +# --------------------------------------------------------------------------- +# Action queue behaviour +# --------------------------------------------------------------------------- + + +@torch.no_grad() +def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.eval() + batch = make_inference_batch() + + # First call fills the queue (n_action_steps items) and pops one. + a1 = policy.select_action(batch) + assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1 + + # Second call pops from the existing queue without calling predict_action_chunk. + a2 = policy.select_action(batch) + assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE + assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE + + +@torch.no_grad() +def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.eval() + policy.select_action(make_inference_batch()) + assert len(policy._queues[ACTION]) > 0 + + policy.reset() + assert len(policy._queues[ACTION]) == 0 + + +# --------------------------------------------------------------------------- +# Format conversion +# --------------------------------------------------------------------------- + + +def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None: + import numpy as np + from PIL import Image + + policy = VLAJEPAPolicy(make_config()) + examples = policy._lerobot_to_native(make_train_batch()) + + assert len(examples) == BATCH_SIZE + for ex in examples: + assert set(ex) >= {"image", "video", "lang", "action", "state"} + assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image) + assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C] + assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM) + assert ex["state"].shape == (1, STATE_DIM) + + +def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + for ex in policy._lerobot_to_native(make_inference_batch()): + assert "action" not in ex + assert "image" in ex and "video" in ex and "lang" in ex + + +def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + batch = make_inference_batch() + del batch["task"] + examples = policy._lerobot_to_native(batch) + assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples) + + +def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + batch = make_inference_batch() + batch["task"] = "open the drawer" + assert all(ex["lang"] == "open the drawer" for ex in policy._lerobot_to_native(batch)) + + +def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None: + from lerobot.utils.constants import OBS_STATE + + policy = VLAJEPAPolicy(make_config()) + batch = make_inference_batch() + del batch[OBS_STATE] + assert all("state" not in ex for ex in policy._lerobot_to_native(batch)) + + +def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + loss, logs = policy._native_to_lerobot({"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)}) + assert torch.isfinite(loss) + assert set(logs) == {"action_loss", "wm_loss", "loss"} + assert logs["action_loss"] == pytest.approx(0.5, abs=1e-5) + assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5) + + +def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + _, logs = policy._native_to_lerobot({"wm_loss": torch.tensor(0.3)}) + assert "action_loss" not in logs + assert logs["wm_loss"] == pytest.approx(0.3, abs=1e-5) + + +# --------------------------------------------------------------------------- +# Pretrained checkpoint +# --------------------------------------------------------------------------- + + +def test_pretrained_checkpoint_loads_from_hf_cache() -> None: + import torch from huggingface_hub import hf_hub_download from huggingface_hub.errors import LocalEntryNotFoundError @@ -291,12 +283,10 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None: try: checkpoint_path = hf_hub_download( - repo_id=repo_id, - filename=checkpoint_filename, - local_files_only=True, + repo_id=repo_id, filename=checkpoint_filename, local_files_only=True ) except LocalEntryNotFoundError: - pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.") + pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.") try: checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False) @@ -309,7 +299,6 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None: or checkpoint.get("model") or checkpoint ) - assert isinstance(state_dict, dict) assert len(state_dict) > 0 - assert all(isinstance(key, str) for key in list(state_dict)[:10]) + assert all(isinstance(k, str) for k in list(state_dict)[:10]) diff --git a/tests/policies/vla_jepa/test_world_model.py b/tests/policies/vla_jepa/test_world_model.py new file mode 100644 index 000000000..0077efb3b --- /dev/null +++ b/tests/policies/vla_jepa/test_world_model.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import pytest +import torch + +from lerobot.policies.vla_jepa.world_model import ( + ActionConditionedVideoPredictor, +) + +_ACTION_EMBED_DIM = 8 + + +def _make_predictor( + embed_dim: int = 8, + action_embed_dim: int = _ACTION_EMBED_DIM, + predictor_embed_dim: int = 16, + num_action_tokens: int = 2, +) -> ActionConditionedVideoPredictor: + return ActionConditionedVideoPredictor( + embed_dim=embed_dim, + action_embed_dim=action_embed_dim, + predictor_embed_dim=predictor_embed_dim, + depth=1, + num_heads=2, + mlp_ratio=2.0, + num_action_tokens_per_step=num_action_tokens, + ) + + +@pytest.mark.parametrize( + "batch,num_steps,tokens_per_frame,embed_dim", + [ + (1, 2, 1, 8), + (2, 3, 4, 8), + (4, 5, 2, 16), + ], +) +def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None: + predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM) + frame_tokens = torch.randn(batch, num_steps, tokens_per_frame, embed_dim) + action_tokens = torch.randn(batch, num_steps, 2, _ACTION_EMBED_DIM) + out = predictor(frame_tokens, action_tokens) + assert tuple(out.shape) == (batch, num_steps, tokens_per_frame, embed_dim) + assert torch.isfinite(out).all() + + +def test_predictor_step_mismatch_raises() -> None: + predictor = _make_predictor() + frame_tokens = torch.randn(2, 3, 4, 8) + with pytest.raises(ValueError, match="Expected 3 action steps"): + predictor(frame_tokens, torch.randn(2, 2, 2, 8))