# Patch overview (text before the first "diff --git" header is commentary):
#
#   1. src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
#      The four torch.autocast(...) calls shown in VLAJEPAModel no longer
#      hard-code "cuda". Each affected code path first reads
#      device_type = next(self.parameters()).device.type and passes it as
#      device_type=... to autocast.
#
#   2. tests/policies/vla_jepa/test_vla_jepa.py (new file, 320 lines)
#      Adds tests that build the config with device="cpu". The fixture
#      patch_vla_jepa_external_models monkeypatches Qwen3VLInterface,
#      AutoModel.from_pretrained and AutoVideoProcessor.from_pretrained with
#      the small fakes defined in the same module.
#
# NOTE(review): the dtype=torch.float32 autocast blocks now also receive the
#   parameter device type, and the new test module filters the UserWarning
#   "In CPU autocast, but the target dtype is not supported". Presumably that
#   warning comes from these blocks when running on CPU -- confirm, and
#   consider whether the suppression should live in the test module.
# NOTE(review): test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache calls
#   torch.load(..., weights_only=False) on a cached file whose repo id and
#   filename can be overridden through the VLA_JEPA_PRETRAINED_* environment
#   variables. That is a full unpickle; only run it against a trusted cache.
# NOTE(review): leading whitespace and the position of blank context lines in
#   the modeling_vla_jepa.py hunks are assumed to follow the usual 4-space
#   nesting -- verify with `git apply --check` against the target tree.
#
diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
index 203460199..90b9b42b4 100644
--- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
+++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
@@ -129,7 +129,9 @@ class VLAJEPAModel(nn.Module):
         embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
         embodied_indices = embodied_mask.nonzero(as_tuple=True)
 
-        with torch.autocast("cuda", dtype=torch.bfloat16):
+        device_type = next(self.parameters()).device.type
+
+        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
             qwen_outputs = self.qwen.model(
                 **qwen_inputs,
                 output_hidden_states=True,
@@ -201,7 +203,7 @@ class VLAJEPAModel(nn.Module):
             return {"wm_loss": wm_loss}
 
         # ---- Step 4: Action Head (same as original) ----
-        with torch.autocast("cuda", dtype=torch.float32):
+        with torch.autocast(device_type=device_type, dtype=torch.float32):
             actions_tensor = torch.tensor(
                 np.array(actions), device=last_hidden.device, dtype=torch.float32
             )  # [B, T_full, action_dim]
@@ -249,7 +251,9 @@ class VLAJEPAModel(nn.Module):
         embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
         embodied_indices = embodied_mask.nonzero(as_tuple=True)
 
-        with torch.autocast("cuda", dtype=torch.bfloat16):
+        device_type = next(self.parameters()).device.type
+
+        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
             qwen_outputs = self.qwen.model(
                 **qwen_inputs,
                 output_hidden_states=True,
@@ -266,7 +270,7 @@ class VLAJEPAModel(nn.Module):
                 device=last_hidden.device, dtype=torch.float32
             )
 
-        with torch.autocast("cuda", dtype=torch.float32):
+        with torch.autocast(device_type=device_type, dtype=torch.float32):
             # Cast embodied tokens to float32 for action model compatibility
             pred_actions = self.action_model.predict_action(
                 embodied_action_tokens.float(), state_tensor
diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py
new file mode 100644
index 000000000..0a42013b6
--- /dev/null
+++ b/tests/policies/vla_jepa/test_vla_jepa.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from torch import Tensor, nn
+
+from lerobot.configs.types import FeatureType, PolicyFeature
+from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
+from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+
+
+pytestmark = pytest.mark.filterwarnings(
+    "ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
+)
+
+
+BATCH_SIZE = 2
+ACTION_DIM = 3
+STATE_DIM = 4
+IMAGE_SIZE = 8
+ACTION_HORIZON = 4
+N_ACTION_STEPS = 2
+NUM_VIDEO_FRAMES = 3
+EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
+EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
+PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
+PRETRAINED_SUBFOLDER = "LIBERO"
+
+
+def set_seed_all(seed: int) -> None:
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+
+class _FakeQwenBackbone(nn.Module):
+    def __init__(self, hidden_size: int) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(1))
+        self.config = SimpleNamespace(
+            hidden_size=hidden_size,
+            text_config=SimpleNamespace(hidden_size=hidden_size),
+        )
+
+    @property
+    def device(self) -> torch.device:
+        return self.weight.device
+
+    def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
+        batch_size, seq_len = input_ids.shape
+        hidden_size = self.config.hidden_size
+        values = torch.arange(
+            batch_size * seq_len * hidden_size,
+            device=input_ids.device,
+            dtype=torch.float32,
+        ).view(batch_size, seq_len, hidden_size)
+        hidden = values / values.numel() + self.weight
+        return SimpleNamespace(hidden_states=[hidden])
+
+
+class _FakeQwenInterface(nn.Module):
+    def __init__(self, config: VLAJEPAConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.model = _FakeQwenBackbone(hidden_size=16)
+
+    @staticmethod
+    def _get_torch_dtype(dtype_name: str) -> torch.dtype:
+        return torch.float32 if dtype_name == "float32" else torch.bfloat16
+
+    def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
+        max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
+        action_tokens = [
+            self.config.special_action_token.format(idx)
+            for idx in range(max_action_tokens)
+        ]
+        action_token_ids = list(range(1000, 1000 + max_action_tokens))
+        return action_tokens, action_token_ids, 2000
+
+    def build_inputs(
+        self,
+        images: list[list[Image.Image]],
+        instructions: list[str],
+        action_prompt: str,
+        embodied_prompt: str,
+    ) -> dict[str, Tensor]:
+        batch_size = len(images)
+        del images, instructions, action_prompt, embodied_prompt
+        action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
+        token_ids = (
+            [10]
+            + list(range(1000, 1000 + action_count))
+            + [2000] * self.config.num_embodied_action_tokens_per_instruction
+            + [11]
+        )
+        input_ids = torch.tensor(
+            [token_ids] * batch_size,
+            device=self.model.device,
+            dtype=torch.long,
+        )
+        return {"input_ids": input_ids}
+
+    @staticmethod
+    def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
+        image = image_tensor.detach().cpu()
+        if image.ndim == 3 and image.shape[0] in (1, 3):
+            image = image.permute(1, 2, 0)
+        image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
+        return Image.fromarray(image)
+
+
+class _FakeVideoEncoder(nn.Module):
+    def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(1))
+        self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
+
+    @property
+    def device(self) -> torch.device:
+        return self.weight.device
+
+    def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
+        batch_size, num_frames = pixel_values_videos.shape[:2]
+        hidden_size = self.config.hidden_size
+        frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
+        return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
+
+
+class _FakeVideoProcessor:
+    def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
+        assert return_tensors == "pt"
+        return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
+
+
+@pytest.fixture
+def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
+    from lerobot.policies.vla_jepa import modeling_vla_jepa
+
+    monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
+    monkeypatch.setattr(
+        modeling_vla_jepa.AutoModel,
+        "from_pretrained",
+        lambda *args, **kwargs: _FakeVideoEncoder(),
+    )
+    monkeypatch.setattr(
+        modeling_vla_jepa.AutoVideoProcessor,
+        "from_pretrained",
+        lambda *args, **kwargs: _FakeVideoProcessor(),
+    )
+
+
+def make_config() -> VLAJEPAConfig:
+    config = VLAJEPAConfig(
+        input_features={
+            f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
+            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
+        },
+        output_features={
+            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
+        },
+        device="cpu",
+        chunk_size=ACTION_HORIZON,
+        n_action_steps=N_ACTION_STEPS,
+        future_action_window_size=ACTION_HORIZON - 1,
+        action_dim=ACTION_DIM,
+        state_dim=STATE_DIM,
+        num_video_frames=NUM_VIDEO_FRAMES,
+        num_action_tokens_per_timestep=2,
+        num_embodied_action_tokens_per_instruction=3,
+        num_inference_timesteps=2,
+        action_hidden_size=16,
+        action_num_layers=1,
+        action_num_heads=2,
+        action_attention_head_dim=8,
+        predictor_depth=1,
+        predictor_num_heads=2,
+        predictor_mlp_ratio=2.0,
+    )
+    config.validate_features()
+    return config
+
+
+def make_train_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
+    return {
+        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE),
+        OBS_STATE: torch.randn(batch_size, 1, STATE_DIM),
+        ACTION: torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM),
+        "task": ["pick up the cube"] * batch_size,
+    }
+
+
+def make_inference_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
+    return {
+        f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
+        OBS_STATE: torch.randn(batch_size, STATE_DIM),
+        "task": ["pick up the cube"] * batch_size,
+    }
+
+
+def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
+    set_seed_all(42)
+    policy = VLAJEPAPolicy(make_config())
+    policy.train()
+
+    batch = make_train_batch()
+    batch_before = deepcopy(batch)
+
+    loss, logs = policy.forward(batch)
+
+    assert loss.shape == ()
+    assert torch.isfinite(loss)
+    assert set(logs) == {"action_loss", "wm_loss", "loss"}
+    assert logs["action_loss"] > 0
+    assert logs["wm_loss"] >= 0
+
+    loss.backward()
+    assert any(
+        param.grad is not None
+        for param in policy.model.action_model.parameters()
+        if param.requires_grad
+    )
+    assert set(batch) == set(batch_before)
+    for key, value in batch.items():
+        if isinstance(value, Tensor):
+            assert torch.equal(value, batch_before[key])
+        else:
+            assert value == batch_before[key]
+
+
+@torch.no_grad()
+def test_vla_jepa_action_generation_shape(
+    patch_vla_jepa_external_models: None,
+) -> None:
+    set_seed_all(42)
+    policy = VLAJEPAPolicy(make_config())
+    policy.eval()
+    batch = make_inference_batch()
+
+    action_chunk = policy.predict_action_chunk(batch)
+
+    assert tuple(action_chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
+    assert action_chunk.device.type == "cpu"
+    assert torch.isfinite(action_chunk).all()
+
+    first_action = policy.select_action(batch)
+    second_action = policy.select_action(batch)
+
+    assert tuple(first_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
+    assert tuple(second_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
+    assert torch.isfinite(first_action).all()
+    assert torch.isfinite(second_action).all()
+
+
+@torch.no_grad()
+def test_vla_jepa_inference_reproducibility(
+    patch_vla_jepa_external_models: None,
+) -> None:
+    set_seed_all(42)
+    policy = VLAJEPAPolicy(make_config())
+    policy.eval()
+    batch = make_inference_batch()
+
+    set_seed_all(123)
+    actions_1 = policy.predict_action_chunk(batch)
+
+    set_seed_all(123)
+    actions_2 = policy.predict_action_chunk(batch)
+
+    assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
+    assert torch.allclose(actions_1, actions_2, atol=1e-6)
+
+
+def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
+    from huggingface_hub import hf_hub_download
+    from huggingface_hub.errors import LocalEntryNotFoundError
+
+    repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
+    subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
+    checkpoint_filename = os.environ.get(
+        "VLA_JEPA_PRETRAINED_CHECKPOINT",
+        f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt",
+    )
+
+    try:
+        checkpoint_path = hf_hub_download(
+            repo_id=repo_id,
+            filename=checkpoint_filename,
+            local_files_only=True,
+        )
+    except LocalEntryNotFoundError:
+        pytest.skip(
+            f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache."
+        )
+
+    try:
+        checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
+    except TypeError:
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+    state_dict = (
+        checkpoint.get("state_dict")
+        or checkpoint.get("model_state_dict")
+        or checkpoint.get("model")
+        or checkpoint
+    )
+
+    assert isinstance(state_dict, dict)
+    assert len(state_dict) > 0
+    assert all(isinstance(key, str) for key in list(state_dict)[:10])