mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
adding more tests to ensure good coverage
This commit is contained in:
@@ -0,0 +1,232 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
BATCH_SIZE = 2
|
||||||
|
ACTION_DIM = 3
|
||||||
|
STATE_DIM = 4
|
||||||
|
IMAGE_SIZE = 8
|
||||||
|
ACTION_HORIZON = 4
|
||||||
|
N_ACTION_STEPS = 2
|
||||||
|
NUM_VIDEO_FRAMES = 3
|
||||||
|
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
|
||||||
|
|
||||||
|
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||||
|
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed_all(seed: int) -> None:
    """Seed every random number generator the tests may draw from.

    Covers Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available).

    Args:
        seed: Value applied to every generator.
    """
    # Imported locally so the helper does not depend on a module-level import.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every CUDA device, including the current one.
        torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def make_config(
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
    num_video_frames: int = NUM_VIDEO_FRAMES,
) -> VLAJEPAConfig:
    """Build a small CPU ``VLAJEPAConfig`` for tests and validate its features.

    Args:
        action_dim: Size of the action feature and of ``config.action_dim``.
        state_dim: Size of the state feature and of ``config.state_dim``.
        action_horizon: Used as ``chunk_size``; the future action window is
            set to ``action_horizon - 1``.
        num_video_frames: Forwarded to the config unchanged.

    Returns:
        The config, after ``validate_features()`` has been called on it.
    """
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE))
    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
    action_feature = PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))

    # Horizon-related settings derived from the requested action horizon.
    horizon_kwargs = {
        "chunk_size": action_horizon,
        # Never ask for more executed steps than the chunk contains.
        "n_action_steps": min(N_ACTION_STEPS, action_horizon),
        "future_action_window_size": action_horizon - 1,
    }

    # Deliberately small model hyper-parameters.
    model_kwargs = {
        "num_action_tokens_per_timestep": 2,
        "num_embodied_action_tokens_per_instruction": 3,
        "num_inference_timesteps": 2,
        "action_hidden_size": QWEN_HIDDEN_SIZE,
        "action_num_layers": 1,
        "action_num_heads": 2,
        "action_attention_head_dim": 8,
        "predictor_depth": 1,
        "predictor_num_heads": 2,
        "predictor_mlp_ratio": 2.0,
        "jepa_tubelet_size": 1,
    }

    cfg = VLAJEPAConfig(
        input_features={
            f"{OBS_IMAGES}.laptop": image_feature,
            OBS_STATE: state_feature,
        },
        output_features={ACTION: action_feature},
        device="cpu",
        action_dim=action_dim,
        state_dim=state_dim,
        num_video_frames=num_video_frames,
        **horizon_kwargs,
        **model_kwargs,
    )
    cfg.validate_features()
    return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_batch(
    batch_size: int = BATCH_SIZE,
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
    num_video_frames: int = NUM_VIDEO_FRAMES,
) -> dict[str, Tensor | list[str]]:
    """Create a random training batch.

    Returns:
        A dict with a ``(batch, frames, 3, IMAGE_SIZE, IMAGE_SIZE)`` image
        tensor, a ``(batch, 1, state_dim)`` state tensor, a
        ``(batch, action_horizon, action_dim)`` action tensor and one task
        string per sample.
    """
    # Tensors are drawn in the order images -> state -> actions so that the
    # values produced under a fixed seed stay stable.
    frames = torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE)
    state = torch.randn(batch_size, 1, state_dim)
    actions = torch.randn(batch_size, action_horizon, action_dim)
    tasks = ["pick up the cube"] * batch_size

    batch: dict[str, Tensor | list[str]] = {}
    batch[f"{OBS_IMAGES}.laptop"] = frames
    batch[OBS_STATE] = state
    batch[ACTION] = actions
    batch["task"] = tasks
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
def make_inference_batch(
    batch_size: int = BATCH_SIZE,
    state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
    """Create a random inference batch.

    Returns:
        A dict with a ``(batch, 3, IMAGE_SIZE, IMAGE_SIZE)`` image tensor, a
        ``(batch, state_dim)`` state tensor and one task string per sample.
        No action entry is included.
    """
    # The image is drawn before the state so seeded values stay stable.
    image = torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)
    state = torch.randn(batch_size, state_dim)

    batch: dict[str, Tensor | list[str]] = {}
    batch[f"{OBS_IMAGES}.laptop"] = image
    batch[OBS_STATE] = state
    batch["task"] = ["pick up the cube"] * batch_size
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQwenBackbone(nn.Module):
    """Deterministic test double used in place of the Qwen backbone.

    It ignores token values entirely: the output depends only on the shape
    of ``input_ids`` and on the single learnable ``weight``.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # ``hidden_size`` is exposed both at the top level and under
        # ``text_config``.
        text_config = SimpleNamespace(hidden_size=hidden_size)
        self.config = SimpleNamespace(hidden_size=hidden_size, text_config=text_config)
        # A single parameter gives the module a device and a gradient path.
        self.weight = nn.Parameter(torch.ones(1))

    @property
    def device(self) -> torch.device:
        """Device on which the module's parameter lives."""
        return self.weight.device

    def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
        """Return one ``(batch, seq, hidden)`` hidden state built from a ramp.

        Extra keyword arguments are accepted and ignored.
        """
        n_rows, n_tokens = input_ids.shape
        width = self.config.hidden_size
        ramp = torch.arange(
            n_rows * n_tokens * width,
            device=input_ids.device,
            dtype=torch.float32,
        )
        ramp = ramp.view(n_rows, n_tokens, width)
        # Normalise by the element count, then shift by the parameter.
        shifted = ramp / ramp.numel() + self.weight
        return SimpleNamespace(hidden_states=[shifted])
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQwenInterface(nn.Module):
    """Test double that the fixture installs in place of ``Qwen3VLInterface``.

    Token ids are synthetic: action tokens start at 1000 and the embodied
    action token is 2000.
    """

    def __init__(self, config: VLAJEPAConfig) -> None:
        super().__init__()
        self.config = config
        self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)

    @staticmethod
    def _get_torch_dtype(dtype_name: str) -> torch.dtype:
        """Map ``"float32"`` to ``torch.float32``; anything else to bfloat16."""
        if dtype_name == "float32":
            return torch.float32
        return torch.bfloat16

    def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
        """Return the action token strings, their ids, and the embodied token id."""
        n_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
        template = self.config.special_action_token
        token_strings = [template.format(position) for position in range(n_tokens)]
        token_ids = [1000 + position for position in range(n_tokens)]
        return token_strings, token_ids, 2000

    def build_inputs(
        self,
        images: list[list[Image.Image]],
        instructions: list[str],
        action_prompt: str,
        embodied_prompt: str,
    ) -> dict[str, Tensor]:
        """Build identical fake ``input_ids`` rows, one per entry of ``images``.

        Only ``len(images)`` is used; the instructions and both prompts are
        ignored.
        """
        n_rows = len(images)
        n_action = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
        n_embodied = self.config.num_embodied_action_tokens_per_instruction
        # Layout: start marker, action tokens, embodied tokens, end marker.
        row = [10, *range(1000, 1000 + n_action), *([2000] * n_embodied), 11]
        input_ids = torch.tensor(
            [row] * n_rows,
            device=self.model.device,
            dtype=torch.long,
        )
        return {"input_ids": input_ids}

    @staticmethod
    def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
        """Convert a tensor with values in ``[0, 1]`` to a uint8 PIL image."""
        pixels = image_tensor.detach().cpu()
        is_channels_first = pixels.ndim == 3 and pixels.shape[0] in (1, 3)
        if is_channels_first:
            # Move channels last, as expected by ``Image.fromarray``.
            pixels = pixels.permute(1, 2, 0)
        scaled = pixels.float().clamp(0, 1) * 255
        return Image.fromarray(scaled.to(torch.uint8).numpy())
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeVideoEncoder(nn.Module):
    """Test double returned by the patched ``AutoModel.from_pretrained``.

    Each frame is reduced to its mean value, which is then repeated across
    the hidden dimension.
    """

    def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
        super().__init__()
        self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
        # A single parameter so the module has a well-defined device.
        self.weight = nn.Parameter(torch.ones(1))

    @property
    def device(self) -> torch.device:
        """Device on which the module's parameter lives."""
        return self.weight.device

    def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
        """Return ``(batch, frames, hidden_size)`` features for a 5-D video tensor."""
        n_videos, n_frames = pixel_values_videos.shape[:2]
        width = self.config.hidden_size
        # Collapse dims 2-4 of every frame into one scalar.
        per_frame_mean = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
        return per_frame_mean.unsqueeze(-1).expand(n_videos, n_frames, width)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeVideoProcessor:
    """Test double returned by the patched ``AutoVideoProcessor.from_pretrained``."""

    def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
        """Wrap ``videos`` in a tensor with a new leading dimension of size 1.

        Only ``return_tensors == "pt"`` is supported; anything else trips the
        assertion.
        """
        assert return_tensors == "pt"
        as_tensor = torch.as_tensor(videos)
        return {"pixel_values_videos": as_tensor[None]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap the external models used by ``modeling_vla_jepa`` for local fakes.

    Replaces ``Qwen3VLInterface`` with ``_FakeQwenInterface`` and makes both
    ``AutoModel.from_pretrained`` and ``AutoVideoProcessor.from_pretrained``
    return fake objects. ``monkeypatch`` restores the originals after the
    test finishes.
    """
    # Imported inside the fixture rather than at module import time.
    from lerobot.policies.vla_jepa import modeling_vla_jepa as target

    def _make_fake_encoder(*args: object, **kwargs: object) -> _FakeVideoEncoder:
        # All arguments of the real ``from_pretrained`` are ignored.
        return _FakeVideoEncoder()

    def _make_fake_processor(*args: object, **kwargs: object) -> _FakeVideoProcessor:
        # All arguments of the real ``from_pretrained`` are ignored.
        return _FakeVideoProcessor()

    monkeypatch.setattr(target, "Qwen3VLInterface", _FakeQwenInterface)
    monkeypatch.setattr(target.AutoModel, "from_pretrained", _make_fake_encoder)
    monkeypatch.setattr(target.AutoVideoProcessor, "from_pretrained", _make_fake_processor)
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("diffusers")
|
||||||
|
|
||||||
|
from conftest import (
|
||||||
|
ACTION_DIM,
|
||||||
|
ACTION_HORIZON,
|
||||||
|
BATCH_SIZE,
|
||||||
|
QWEN_HIDDEN_SIZE,
|
||||||
|
STATE_DIM,
|
||||||
|
make_config,
|
||||||
|
set_seed_all,
|
||||||
|
) # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
|
||||||
|
VLAJEPAActionHead,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# VLAJEPAActionHead
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("action_dim", "state_dim", "action_horizon"),
    [
        (3, 4, 4),  # default test dims
        (7, 0, 16),  # no proprioceptive state, production-like action space
        (6, 8, 8),  # medium dims
    ],
)
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """``sample_time`` returns one finite value per requested sample.

    NOTE(review): despite the name, no lower/upper bound is asserted here;
    confirm the intended range against the action head before adding one.
    """
    n_samples = 200
    cfg = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    action_head = VLAJEPAActionHead(cfg, cross_attention_dim=QWEN_HIDDEN_SIZE)

    sampled = action_head.sample_time(
        batch_size=n_samples,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )

    assert sampled.shape == (n_samples,)
    assert torch.isfinite(sampled).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("action_dim", "state_dim", "action_horizon"),
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """``_build_inputs`` yields finite 3-D tensors, longer when a state is given."""
    n_batch = 2
    has_state = state_dim > 0

    cfg = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    action_head = VLAJEPAActionHead(cfg, cross_attention_dim=QWEN_HIDDEN_SIZE)
    cond = torch.randn(n_batch, 4, QWEN_HIDDEN_SIZE)
    noisy_actions = torch.randn(n_batch, action_horizon, action_dim)
    steps = torch.randint(0, 100, (n_batch,))
    proprio = torch.randn(n_batch, state_dim) if has_state else None

    with_state = action_head._build_inputs(cond, noisy_actions, proprio, steps)
    without_state = action_head._build_inputs(cond, noisy_actions, None, steps)

    assert with_state.ndim == 3
    assert without_state.ndim == 3
    if has_state:
        # Supplying a state must lengthen dim 1 of the result.
        assert with_state.shape[1] > without_state.shape[1]
    assert torch.isfinite(with_state).all()
    assert torch.isfinite(without_state).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("action_dim", "state_dim", "action_horizon"),
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """The training forward pass returns a finite, strictly positive scalar."""
    set_seed_all(42)
    n_batch = 2

    cfg = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    action_head = VLAJEPAActionHead(cfg, cross_attention_dim=QWEN_HIDDEN_SIZE)
    cond = torch.randn(n_batch, 4, QWEN_HIDDEN_SIZE)
    target_actions = torch.randn(n_batch, action_horizon, action_dim)
    # With ``state_dim == 0`` the head is exercised without any state input.
    proprio = torch.randn(n_batch, state_dim) if state_dim > 0 else None

    loss = action_head.forward(cond, target_actions, proprio)

    assert loss.shape == ()
    assert torch.isfinite(loss)
    assert loss > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_head_forward_gradient_flows() -> None:
    """Back-propagating the loss gives a gradient to at least one trainable parameter."""
    set_seed_all(42)
    action_head = VLAJEPAActionHead(make_config(), cross_attention_dim=QWEN_HIDDEN_SIZE)
    cond = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
    target_actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
    proprio = torch.randn(BATCH_SIZE, STATE_DIM)

    action_head.forward(cond, target_actions, proprio).backward()

    trainable = [param for param in action_head.parameters() if param.requires_grad]
    assert any(param.grad is not None for param in trainable)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
@pytest.mark.parametrize(
    ("action_dim", "state_dim", "action_horizon"),
    [
        (3, 4, 4),
        (7, 0, 16),
        (6, 8, 8),
    ],
)
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
    """``predict_action`` returns finite ``(batch, horizon, action_dim)`` actions."""
    set_seed_all(42)
    n_batch = 2

    cfg = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
    action_head = VLAJEPAActionHead(cfg, cross_attention_dim=QWEN_HIDDEN_SIZE)
    cond = torch.randn(n_batch, 4, QWEN_HIDDEN_SIZE)
    proprio = torch.randn(n_batch, state_dim) if state_dim > 0 else None

    predicted = action_head.predict_action(cond, proprio)

    expected_shape = (n_batch, action_horizon, action_dim)
    assert tuple(predicted.shape) == expected_shape
    assert torch.isfinite(predicted).all()
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
|
def test_delta_indices() -> None:
    """Delta indices count up from 0 over the video frames and the action horizon."""
    cfg = make_config()
    expected_observation = list(range(NUM_VIDEO_FRAMES))
    expected_action = list(range(ACTION_HORIZON))

    assert cfg.observation_delta_indices == expected_observation
    assert cfg.action_delta_indices == expected_action
|
||||||
|
|
||||||
|
|
||||||
|
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
    """Asking for more action steps than the chunk size is rejected at construction."""
    invalid_kwargs = {
        "chunk_size": 4,
        "n_action_steps": 8,  # larger than chunk_size
        "future_action_window_size": 3,
    }
    with pytest.raises(ValueError, match="n_action_steps"):
        VLAJEPAConfig(**invalid_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_future_window_exceeds_chunk_size_raises() -> None:
    """A future action window equal to the chunk size is rejected at construction."""
    invalid_kwargs = {
        "chunk_size": 4,
        "n_action_steps": 4,
        "future_action_window_size": 4,  # same as chunk_size
    }
    with pytest.raises(ValueError, match="predicted action horizon"):
        VLAJEPAConfig(**invalid_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_too_few_video_frames_raises() -> None:
    """Two video frames with a tubelet size of two are rejected at construction."""
    invalid_kwargs = {
        "chunk_size": 16,
        "n_action_steps": 16,
        "future_action_window_size": 15,
        "num_video_frames": 2,
        # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
        "jepa_tubelet_size": 2,
    }
    with pytest.raises(ValueError, match="video_horizon"):
        VLAJEPAConfig(**invalid_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_features_no_image_raises() -> None:
    """``validate_features`` rejects a config whose only input is a state feature."""
    state_only_inputs = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))}
    action_outputs = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
    cfg = VLAJEPAConfig(input_features=state_only_inputs, output_features=action_outputs)

    # Construction succeeds; the error is raised by the explicit validation call.
    with pytest.raises(ValueError, match="at least one visual input feature"):
        cfg.validate_features()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_features_no_action_raises() -> None:
    """``validate_features`` rejects a config with no output features."""
    camera_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE))
    cfg = VLAJEPAConfig(
        input_features={f"{OBS_IMAGES}.cam": camera_feature},
        output_features={},
    )

    # Construction succeeds; the error is raised by the explicit validation call.
    with pytest.raises(ValueError, match="action output feature"):
        cfg.validate_features()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_features_sets_action_dim_from_feature() -> None:
    """After validation, ``action_dim`` and ``state_dim`` match the requested sizes.

    NOTE(review): ``make_config`` passes the same values both as explicit
    config arguments and as feature shapes, so this does not isolate whether
    the dims are derived from the features — confirm the intent.
    """
    requested_action_dim = 6
    requested_state_dim = 10
    cfg = make_config(action_dim=requested_action_dim, state_dim=requested_state_dim)

    assert cfg.action_dim == requested_action_dim
    assert cfg.state_dim == requested_state_dim
|
||||||
@@ -4,18 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from torch import Tensor
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
||||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
|
||||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
|
||||||
|
|
||||||
pytest.importorskip("transformers")
|
pytest.importorskip("transformers")
|
||||||
pytest.importorskip("diffusers")
|
pytest.importorskip("diffusers")
|
||||||
@@ -24,190 +16,33 @@ pytestmark = pytest.mark.filterwarnings(
|
|||||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from conftest import ( # noqa: E402
|
||||||
|
ACTION_DIM,
|
||||||
|
ACTION_HORIZON,
|
||||||
|
BATCH_SIZE,
|
||||||
|
EXPECTED_ACTION_CHUNK_SHAPE,
|
||||||
|
EXPECTED_SELECT_ACTION_SHAPE,
|
||||||
|
N_ACTION_STEPS,
|
||||||
|
STATE_DIM,
|
||||||
|
make_config,
|
||||||
|
make_inference_batch,
|
||||||
|
make_train_batch,
|
||||||
|
set_seed_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
|
||||||
|
from lerobot.utils.constants import ACTION # noqa: E402
|
||||||
|
|
||||||
BATCH_SIZE = 2
|
|
||||||
ACTION_DIM = 3
|
|
||||||
STATE_DIM = 4
|
|
||||||
IMAGE_SIZE = 8
|
|
||||||
ACTION_HORIZON = 4
|
|
||||||
N_ACTION_STEPS = 2
|
|
||||||
NUM_VIDEO_FRAMES = 3
|
|
||||||
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
|
||||||
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
|
||||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||||
|
|
||||||
|
|
||||||
def set_seed_all(seed: int) -> None:
|
# ---------------------------------------------------------------------------
|
||||||
np.random.seed(seed)
|
# Core training / inference tests
|
||||||
torch.manual_seed(seed)
|
# ---------------------------------------------------------------------------
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeQwenBackbone(nn.Module):
|
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||||
def __init__(self, hidden_size: int) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(1))
|
|
||||||
self.config = SimpleNamespace(
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
text_config=SimpleNamespace(hidden_size=hidden_size),
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
return self.weight.device
|
|
||||||
|
|
||||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
|
||||||
batch_size, seq_len = input_ids.shape
|
|
||||||
hidden_size = self.config.hidden_size
|
|
||||||
values = torch.arange(
|
|
||||||
batch_size * seq_len * hidden_size,
|
|
||||||
device=input_ids.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
).view(batch_size, seq_len, hidden_size)
|
|
||||||
hidden = values / values.numel() + self.weight
|
|
||||||
return SimpleNamespace(hidden_states=[hidden])
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeQwenInterface(nn.Module):
|
|
||||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.model = _FakeQwenBackbone(hidden_size=16)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
|
||||||
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
|
||||||
|
|
||||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
|
||||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
|
||||||
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
|
||||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
|
||||||
return action_tokens, action_token_ids, 2000
|
|
||||||
|
|
||||||
def build_inputs(
|
|
||||||
self,
|
|
||||||
images: list[list[Image.Image]],
|
|
||||||
instructions: list[str],
|
|
||||||
action_prompt: str,
|
|
||||||
embodied_prompt: str,
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
batch_size = len(images)
|
|
||||||
del images, instructions, action_prompt, embodied_prompt
|
|
||||||
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
|
||||||
token_ids = (
|
|
||||||
[10]
|
|
||||||
+ list(range(1000, 1000 + action_count))
|
|
||||||
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
|
||||||
+ [11]
|
|
||||||
)
|
|
||||||
input_ids = torch.tensor(
|
|
||||||
[token_ids] * batch_size,
|
|
||||||
device=self.model.device,
|
|
||||||
dtype=torch.long,
|
|
||||||
)
|
|
||||||
return {"input_ids": input_ids}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
|
||||||
image = image_tensor.detach().cpu()
|
|
||||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
|
||||||
image = image.permute(1, 2, 0)
|
|
||||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
|
||||||
return Image.fromarray(image)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeVideoEncoder(nn.Module):
|
|
||||||
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(1))
|
|
||||||
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
return self.weight.device
|
|
||||||
|
|
||||||
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
|
||||||
batch_size, num_frames = pixel_values_videos.shape[:2]
|
|
||||||
hidden_size = self.config.hidden_size
|
|
||||||
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
|
||||||
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeVideoProcessor:
|
|
||||||
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
|
|
||||||
assert return_tensors == "pt"
|
|
||||||
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
||||||
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
|
||||||
|
|
||||||
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
modeling_vla_jepa.AutoModel,
|
|
||||||
"from_pretrained",
|
|
||||||
lambda *args, **kwargs: _FakeVideoEncoder(),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
modeling_vla_jepa.AutoVideoProcessor,
|
|
||||||
"from_pretrained",
|
|
||||||
lambda *args, **kwargs: _FakeVideoProcessor(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_config() -> VLAJEPAConfig:
|
|
||||||
config = VLAJEPAConfig(
|
|
||||||
input_features={
|
|
||||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
|
||||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
|
||||||
},
|
|
||||||
output_features={
|
|
||||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
|
|
||||||
},
|
|
||||||
device="cpu",
|
|
||||||
chunk_size=ACTION_HORIZON,
|
|
||||||
n_action_steps=N_ACTION_STEPS,
|
|
||||||
future_action_window_size=ACTION_HORIZON - 1,
|
|
||||||
action_dim=ACTION_DIM,
|
|
||||||
state_dim=STATE_DIM,
|
|
||||||
num_video_frames=NUM_VIDEO_FRAMES,
|
|
||||||
num_action_tokens_per_timestep=2,
|
|
||||||
num_embodied_action_tokens_per_instruction=3,
|
|
||||||
num_inference_timesteps=2,
|
|
||||||
action_hidden_size=16,
|
|
||||||
action_num_layers=1,
|
|
||||||
action_num_heads=2,
|
|
||||||
action_attention_head_dim=8,
|
|
||||||
predictor_depth=1,
|
|
||||||
predictor_num_heads=2,
|
|
||||||
predictor_mlp_ratio=2.0,
|
|
||||||
)
|
|
||||||
config.validate_features()
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def make_train_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
|
|
||||||
return {
|
|
||||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE),
|
|
||||||
OBS_STATE: torch.randn(batch_size, 1, STATE_DIM),
|
|
||||||
ACTION: torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM),
|
|
||||||
"task": ["pick up the cube"] * batch_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def make_inference_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
|
|
||||||
return {
|
|
||||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
|
||||||
OBS_STATE: torch.randn(batch_size, STATE_DIM),
|
|
||||||
"task": ["pick up the cube"] * batch_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
|
||||||
set_seed_all(42)
|
set_seed_all(42)
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -224,9 +59,8 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
|
|||||||
assert logs["wm_loss"] >= 0
|
assert logs["wm_loss"] >= 0
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
assert any(
|
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
|
||||||
param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad
|
# Batch must not be mutated.
|
||||||
)
|
|
||||||
assert set(batch) == set(batch_before)
|
assert set(batch) == set(batch_before)
|
||||||
for key, value in batch.items():
|
for key, value in batch.items():
|
||||||
if isinstance(value, Tensor):
|
if isinstance(value, Tensor):
|
||||||
@@ -235,34 +69,75 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
|
|||||||
assert value == batch_before[key]
|
assert value == batch_before[key]
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||||
def test_vla_jepa_action_generation_shape(
|
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.train()
|
||||||
|
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
|
||||||
|
assert torch.isfinite(loss) and loss > 0
|
||||||
|
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"action_dim,state_dim,action_horizon",
|
||||||
|
[
|
||||||
|
(3, 4, 4),
|
||||||
|
(7, 0, 16),
|
||||||
|
(6, 8, 8),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_training_forward_various_dims(
|
||||||
patch_vla_jepa_external_models: None,
|
patch_vla_jepa_external_models: None,
|
||||||
|
action_dim: int,
|
||||||
|
state_dim: int,
|
||||||
|
action_horizon: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
policy = VLAJEPAPolicy(config)
|
||||||
|
policy.train()
|
||||||
|
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
loss, _ = policy.forward(batch)
|
||||||
|
assert torch.isfinite(loss) and loss > 0
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
|
||||||
set_seed_all(42)
|
set_seed_all(42)
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
policy.eval()
|
policy.eval()
|
||||||
batch = make_inference_batch()
|
batch = make_inference_batch()
|
||||||
|
|
||||||
action_chunk = policy.predict_action_chunk(batch)
|
chunk = policy.predict_action_chunk(batch)
|
||||||
|
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||||
|
assert chunk.device.type == "cpu"
|
||||||
|
assert torch.isfinite(chunk).all()
|
||||||
|
|
||||||
assert tuple(action_chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
a1 = policy.select_action(batch)
|
||||||
assert action_chunk.device.type == "cpu"
|
a2 = policy.select_action(batch)
|
||||||
assert torch.isfinite(action_chunk).all()
|
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
|
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
first_action = policy.select_action(batch)
|
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
|
||||||
second_action = policy.select_action(batch)
|
|
||||||
|
|
||||||
assert tuple(first_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
|
||||||
assert tuple(second_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
|
||||||
assert torch.isfinite(first_action).all()
|
|
||||||
assert torch.isfinite(second_action).all()
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_vla_jepa_inference_reproducibility(
|
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
|
||||||
patch_vla_jepa_external_models: None,
|
def test_action_generation_various_dims(
|
||||||
|
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
|
||||||
) -> None:
|
) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim)
|
||||||
|
policy = VLAJEPAPolicy(config)
|
||||||
|
policy.eval()
|
||||||
|
batch = make_inference_batch(state_dim=state_dim)
|
||||||
|
chunk = policy.predict_action_chunk(batch)
|
||||||
|
assert chunk.shape[-1] == action_dim
|
||||||
|
assert torch.isfinite(chunk).all()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
|
||||||
set_seed_all(42)
|
set_seed_all(42)
|
||||||
policy = VLAJEPAPolicy(make_config())
|
policy = VLAJEPAPolicy(make_config())
|
||||||
policy.eval()
|
policy.eval()
|
||||||
@@ -270,7 +145,6 @@ def test_vla_jepa_inference_reproducibility(
|
|||||||
|
|
||||||
set_seed_all(123)
|
set_seed_all(123)
|
||||||
actions_1 = policy.predict_action_chunk(batch)
|
actions_1 = policy.predict_action_chunk(batch)
|
||||||
|
|
||||||
set_seed_all(123)
|
set_seed_all(123)
|
||||||
actions_2 = policy.predict_action_chunk(batch)
|
actions_2 = policy.predict_action_chunk(batch)
|
||||||
|
|
||||||
@@ -278,7 +152,125 @@ def test_vla_jepa_inference_reproducibility(
|
|||||||
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
@torch.no_grad()
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
    """Predicted action chunks must contain only finite values for several seeds.

    A single policy instance is reused; only the RNG seed and the freshly built
    inference batch change between iterations.
    """
    policy = VLAJEPAPolicy(make_config())
    policy.eval()

    for seed in (0, 42, 123):
        # Re-seed before building the batch so each iteration is reproducible.
        set_seed_all(seed)
        actions = policy.predict_action_chunk(make_inference_batch())
        all_finite = torch.isfinite(actions).all()
        assert all_finite, f"non-finite actions with seed={seed}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Action queue behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
    """Check that ``select_action`` serves queued actions before refilling.

    The first call is expected to leave ``N_ACTION_STEPS - 1`` actions queued.
    The second call is expected to consume one of those, which is verified by
    asserting the queue shrank by exactly one more element. The arithmetic
    requires ``N_ACTION_STEPS >= 2`` so that the second call has something
    left to drain.
    """
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()

    # First call fills the queue (n_action_steps items) and pops one.
    a1 = policy.select_action(batch)
    assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1

    # Second call pops from the existing queue without calling predict_action_chunk.
    a2 = policy.select_action(batch)
    # A refill would push the length back up; a pure drain leaves exactly
    # N_ACTION_STEPS - 2 entries behind.
    assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 2

    assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
    assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
    """``reset`` must empty the action queue left behind by ``select_action``."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()

    # A single select_action call is expected to leave actions queued.
    policy.select_action(make_inference_batch())
    queued_before_reset = len(policy._queues[ACTION])
    assert queued_before_reset > 0

    policy.reset()
    # Look the queue up again after reset rather than reusing an old reference.
    queued_after_reset = len(policy._queues[ACTION])
    assert queued_after_reset == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Format conversion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None:
    """A training batch converts to one native example per sample with all expected fields."""
    import numpy as np
    from PIL import Image

    policy = VLAJEPAPolicy(make_config())
    native_examples = policy._lerobot_to_native(make_train_batch())

    assert len(native_examples) == BATCH_SIZE

    required_keys = {"image", "video", "lang", "action", "state"}
    for example in native_examples:
        assert required_keys <= set(example)

        images = example["image"]
        assert len(images) == 1
        assert isinstance(images[0], Image.Image)

        video = example["video"]
        assert video.ndim == 5  # [V,T,H,W,C]
        assert video.dtype == np.uint8

        assert example["action"].shape == (ACTION_HORIZON, ACTION_DIM)
        assert example["state"].shape == (1, STATE_DIM)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
    """Inference batches yield examples without ``action`` but with image, video and lang."""
    policy = VLAJEPAPolicy(make_config())
    native_examples = policy._lerobot_to_native(make_inference_batch())

    for example in native_examples:
        assert "action" not in example
        for required_key in ("image", "video", "lang"):
            assert required_key in example
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
    """When the batch has no ``task`` key, every example still gets a non-empty string ``lang``."""
    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    del batch["task"]

    for example in policy._lerobot_to_native(batch):
        instruction = example["lang"]
        assert isinstance(instruction, str)
        assert len(instruction) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
    """A single string ``task`` is applied to every example in the batch."""
    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    batch["task"] = "open the drawer"

    for example in policy._lerobot_to_native(batch):
        assert example["lang"] == "open the drawer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
    """Without an observation-state entry in the batch, no example carries a ``state`` key."""
    from lerobot.utils.constants import OBS_STATE

    policy = VLAJEPAPolicy(make_config())
    batch = make_inference_batch()
    del batch[OBS_STATE]

    for example in policy._lerobot_to_native(batch):
        assert "state" not in example
|
||||||
|
|
||||||
|
|
||||||
|
def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None:
    """With both native losses present, the logs expose each one plus a combined ``loss``."""
    policy = VLAJEPAPolicy(make_config())
    native_losses = {"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)}

    loss, logs = policy._native_to_lerobot(native_losses)

    assert torch.isfinite(loss)
    assert set(logs) == {"action_loss", "wm_loss", "loss"}
    # Individual losses must be reported unchanged (up to float tolerance).
    assert logs["action_loss"] == pytest.approx(0.5, abs=1e-5)
    assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_native_to_lerobot_wm_only(patch_vla_jepa_external_models: None) -> None:
    """With only a world-model loss, the logs omit ``action_loss`` and keep ``wm_loss``."""
    policy = VLAJEPAPolicy(make_config())
    native_losses = {"wm_loss": torch.tensor(0.3)}

    _, logs = policy._native_to_lerobot(native_losses)

    assert "action_loss" not in logs
    assert logs["wm_loss"] == pytest.approx(0.3, abs=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pretrained checkpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
||||||
|
import torch
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||||
|
|
||||||
@@ -291,12 +283,10 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint_path = hf_hub_download(
|
checkpoint_path = hf_hub_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id, filename=checkpoint_filename, local_files_only=True
|
||||||
filename=checkpoint_filename,
|
|
||||||
local_files_only=True,
|
|
||||||
)
|
)
|
||||||
except LocalEntryNotFoundError:
|
except LocalEntryNotFoundError:
|
||||||
pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.")
|
pytest.skip(f"{repo_id}/{checkpoint_filename} is not in the local HF cache.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
||||||
@@ -309,7 +299,6 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
|||||||
or checkpoint.get("model")
|
or checkpoint.get("model")
|
||||||
or checkpoint
|
or checkpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(state_dict, dict)
|
assert isinstance(state_dict, dict)
|
||||||
assert len(state_dict) > 0
|
assert len(state_dict) > 0
|
||||||
assert all(isinstance(key, str) for key in list(state_dict)[:10])
|
assert all(isinstance(k, str) for k in list(state_dict)[:10])
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.world_model import (
|
||||||
|
ActionConditionedVideoPredictor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Last-dimension size of the action tokens built in these tests; also the
# default ``action_embed_dim`` passed to ``_make_predictor``.
_ACTION_EMBED_DIM = 8
|
||||||
|
|
||||||
|
|
||||||
|
def _make_predictor(
    embed_dim: int = 8,
    action_embed_dim: int = _ACTION_EMBED_DIM,
    predictor_embed_dim: int = 16,
    num_action_tokens: int = 2,
) -> ActionConditionedVideoPredictor:
    """Build a small ``ActionConditionedVideoPredictor`` for unit tests.

    Args:
        embed_dim: Forwarded as ``embed_dim``.
        action_embed_dim: Forwarded as ``action_embed_dim``.
        predictor_embed_dim: Forwarded as ``predictor_embed_dim``.
        num_action_tokens: Forwarded as ``num_action_tokens_per_step``.

    Returns:
        A predictor with fixed ``depth=1``, ``num_heads=2`` and ``mlp_ratio=2.0``.
    """
    # Architecture knobs that the tests never vary.
    fixed_kwargs = {"depth": 1, "num_heads": 2, "mlp_ratio": 2.0}
    predictor = ActionConditionedVideoPredictor(
        embed_dim=embed_dim,
        action_embed_dim=action_embed_dim,
        predictor_embed_dim=predictor_embed_dim,
        num_action_tokens_per_step=num_action_tokens,
        **fixed_kwargs,
    )
    return predictor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "batch,num_steps,tokens_per_frame,embed_dim",
    [
        (1, 2, 1, 8),
        (2, 3, 4, 8),
        (4, 5, 2, 16),
    ],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
    """The predictor output has the same shape as the frame tokens and is finite."""
    # Keep the action-token count in one place so the predictor and its input agree.
    action_tokens_per_step = 2
    predictor = _make_predictor(
        embed_dim=embed_dim,
        action_embed_dim=_ACTION_EMBED_DIM,
        num_action_tokens=action_tokens_per_step,
    )

    frame_shape = (batch, num_steps, tokens_per_frame, embed_dim)
    frame_tokens = torch.randn(*frame_shape)
    action_tokens = torch.randn(batch, num_steps, action_tokens_per_step, _ACTION_EMBED_DIM)

    prediction = predictor(frame_tokens, action_tokens)

    assert tuple(prediction.shape) == frame_shape
    assert torch.isfinite(prediction).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_predictor_step_mismatch_raises() -> None:
    """Frame and action tokens with different step counts must raise ``ValueError``."""
    predictor = _make_predictor()
    frame_tokens = torch.randn(2, 3, 4, 8)
    # Only 2 action steps against 3 frame steps.
    mismatched_action_tokens = torch.randn(2, 2, 2, 8)

    with pytest.raises(ValueError, match="Expected 3 action steps"):
        predictor(frame_tokens, mismatched_action_tokens)
|
||||||
Reference in New Issue
Block a user