mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
231 lines
7.9 KiB
Python
231 lines
7.9 KiB
Python
#!/usr/bin/env python
|
|
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from types import SimpleNamespace
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from PIL import Image
|
|
from torch import Tensor, nn
|
|
|
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared constants
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Number of samples in the batches built by the helpers in this module.
BATCH_SIZE = 2

# Per-sample feature sizes declared in the test config.
ACTION_DIM = 3
STATE_DIM = 4
IMAGE_SIZE = 8  # test images are square: (3, IMAGE_SIZE, IMAGE_SIZE)

# Temporal sizes: action chunk length, executed steps, and video frames.
ACTION_HORIZON = 4
N_ACTION_STEPS = 2
NUM_VIDEO_FRAMES = 3

QWEN_HIDDEN_SIZE = 16  # hidden size produced by _FakeQwenBackbone

# Shapes derived from the sizes above.
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def set_seed_all(seed: int) -> None:
    """Seed every random number generator this module's helpers may rely on.

    The same ``seed`` is applied to Python's built-in ``random`` module,
    NumPy's global generator, PyTorch's default generator and, when CUDA is
    available, the generators of all CUDA devices.

    Args:
        seed: Value handed to each library's seeding function.
    """
    # Local import keeps this helper self-contained.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # ``manual_seed_all`` seeds every CUDA device, the current one included.
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def make_config(
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
    num_video_frames: int = NUM_VIDEO_FRAMES,
) -> VLAJEPAConfig:
    """Create a CPU ``VLAJEPAConfig`` with one camera, one state and one action feature.

    Args:
        action_dim: Size of the action feature; also passed as ``action_dim``.
        state_dim: Size of the state feature; also passed as ``state_dim``.
        action_horizon: Passed as ``chunk_size``; also caps ``n_action_steps``.
        num_video_frames: Passed through as ``num_video_frames``.

    Returns:
        The config, after ``validate_features()`` has been called on it.
    """
    camera_key = f"{OBS_IMAGES}.laptop"
    inputs = {
        camera_key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
    }
    outputs = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}

    # Fixed size settings for the action head and the predictor.
    size_overrides = {
        "num_action_tokens_per_timestep": 2,
        "num_embodied_action_tokens_per_instruction": 3,
        "num_inference_timesteps": 2,
        "action_hidden_size": QWEN_HIDDEN_SIZE,
        "action_model_type": "DiT-test",
        "action_num_layers": 1,
        "predictor_depth": 1,
        "predictor_num_heads": 2,
        "predictor_mlp_ratio": 2.0,
        "jepa_tubelet_size": 1,
    }

    cfg = VLAJEPAConfig(
        input_features=inputs,
        output_features=outputs,
        device="cpu",
        chunk_size=action_horizon,
        # Never request more executed steps than the chunk contains.
        n_action_steps=min(N_ACTION_STEPS, action_horizon),
        action_dim=action_dim,
        state_dim=state_dim,
        num_video_frames=num_video_frames,
        **size_overrides,
    )
    cfg.validate_features()
    return cfg
|
|
|
|
|
|
def make_train_batch(
    batch_size: int = BATCH_SIZE,
    action_dim: int = ACTION_DIM,
    state_dim: int = STATE_DIM,
    action_horizon: int = ACTION_HORIZON,
    num_video_frames: int = NUM_VIDEO_FRAMES,
) -> dict[str, Tensor | list[str]]:
    """Build a random training batch keyed like the policy's features.

    Args:
        batch_size: Number of samples.
        action_dim: Last dimension of the action tensor.
        state_dim: Last dimension of the state tensor.
        action_horizon: Number of action steps per sample.
        num_video_frames: Number of frames per sample in the image tensor.

    Returns:
        A dict holding a ``(batch, frames, 3, IMAGE_SIZE, IMAGE_SIZE)`` image
        tensor drawn with ``torch.rand``, a ``(batch, 1, state_dim)`` state and
        a ``(batch, action_horizon, action_dim)`` action drawn with
        ``torch.randn``, and one task string per sample.
    """
    # Draw order (images, state, action) is kept fixed so seeded runs match.
    frames = torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE)
    state = torch.randn(batch_size, 1, state_dim)
    actions = torch.randn(batch_size, action_horizon, action_dim)
    tasks = ["pick up the cube"] * batch_size

    batch: dict[str, Tensor | list[str]] = {}
    batch[f"{OBS_IMAGES}.laptop"] = frames
    batch[OBS_STATE] = state
    batch[ACTION] = actions
    batch["task"] = tasks
    return batch
|
|
|
|
|
|
def make_inference_batch(
    batch_size: int = BATCH_SIZE,
    state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
    """Build a random single-step observation batch (no action entry).

    Args:
        batch_size: Number of samples.
        state_dim: Last dimension of the state tensor.

    Returns:
        A dict holding a ``(batch, 3, IMAGE_SIZE, IMAGE_SIZE)`` image tensor
        drawn with ``torch.rand``, a ``(batch, state_dim)`` state drawn with
        ``torch.randn``, and one task string per sample.
    """
    # Draw order (image, then state) is kept fixed so seeded runs match.
    image = torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)
    state = torch.randn(batch_size, state_dim)
    tasks = ["pick up the cube"] * batch_size

    batch: dict[str, Tensor | list[str]] = {}
    batch[f"{OBS_IMAGES}.laptop"] = image
    batch[OBS_STATE] = state
    batch["task"] = tasks
    return batch
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeQwenBackbone(nn.Module):
    """Deterministic stand-in for the Qwen backbone.

    It exposes a ``config`` namespace carrying ``hidden_size`` (also mirrored
    under ``text_config``), a ``device`` property, and a ``forward`` whose
    output depends only on the shape and device of ``input_ids`` and on the
    single ``weight`` parameter.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # One scalar parameter: it backs ``device`` and is added to the output.
        self.weight = nn.Parameter(torch.ones(1))
        text_config = SimpleNamespace(hidden_size=hidden_size)
        self.config = SimpleNamespace(hidden_size=hidden_size, text_config=text_config)

    @property
    def device(self) -> torch.device:
        """Device on which the module's parameter lives."""
        return self.weight.device

    def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
        """Return a namespace whose ``hidden_states`` list holds one tensor.

        Args:
            input_ids: 2-D tensor; only its shape and device are used.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            ``SimpleNamespace(hidden_states=[hidden])`` where ``hidden`` has
            shape ``(batch, seq_len, hidden_size)`` and equals a 0..N-1 ramp
            divided by N, plus ``weight``.
        """
        n_rows, n_tokens = input_ids.shape
        width = self.config.hidden_size
        ramp = torch.arange(
            n_rows * n_tokens * width,
            device=input_ids.device,
            dtype=torch.float32,
        )
        ramp = ramp.view(n_rows, n_tokens, width)
        hidden = ramp / ramp.numel() + self.weight
        return SimpleNamespace(hidden_states=[hidden])
|
|
|
|
|
|
class _FakeQwenInterface(nn.Module):
    """Test replacement for ``Qwen3VLInterface`` that needs no real checkpoint.

    It wraps a ``_FakeQwenBackbone`` and produces fixed token layouts derived
    only from the config, so tests get deterministic inputs.
    """

    def __init__(self, config: VLAJEPAConfig) -> None:
        """Store ``config`` and create the fake backbone as ``self.model``."""
        super().__init__()
        self.config = config
        self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)

    @staticmethod
    def _get_torch_dtype(dtype_name: str) -> torch.dtype:
        """Map ``"float32"`` to ``torch.float32``; any other name to ``torch.bfloat16``."""
        return torch.float32 if dtype_name == "float32" else torch.bfloat16

    def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
        """Return fake action tokens, their ids, and the embodied-action token id.

        Returns:
            A tuple ``(action_tokens, action_token_ids, embodied_id)`` where
            there are ``chunk_size * num_action_tokens_per_timestep`` action
            tokens formatted from ``config.special_action_token``, ids start at
            1000, and ``embodied_id`` is always 2000.
        """
        max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
        action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
        action_token_ids = list(range(1000, 1000 + max_action_tokens))
        return action_tokens, action_token_ids, 2000

    def build_inputs(
        self,
        images: list[list[Image.Image]],
        instructions: list[str],
        action_prompt: str,
        embodied_prompt: str,
    ) -> dict[str, Tensor]:
        """Build identical ``input_ids`` rows, one per entry of ``images``.

        Only ``len(images)`` is used; the image contents, instructions and
        prompts are ignored.

        Returns:
            ``{"input_ids": tensor}`` of dtype ``torch.long`` on the backbone's
            device. Each row is ``10``, then
            ``(num_video_frames - 1) * num_action_tokens_per_timestep`` ids
            counting up from 1000, then ``2000`` repeated
            ``num_embodied_action_tokens_per_instruction`` times, then ``11``.
        """
        batch_size = len(images)
        # The remaining arguments exist only to match the real signature.
        del images, instructions, action_prompt, embodied_prompt
        action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
        token_ids = (
            [10]
            + list(range(1000, 1000 + action_count))
            + [2000] * self.config.num_embodied_action_tokens_per_instruction
            + [11]
        )
        return {
            "input_ids": torch.tensor(
                [token_ids] * batch_size,
                device=self.model.device,
                dtype=torch.long,
            )
        }

    @staticmethod
    def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
        """Convert an image tensor with values in [0, 1] to a PIL image.

        A 3-D tensor whose first dimension is 1 or 3 is treated as
        channels-first and moved to channels-last. Values are clamped to
        [0, 1], scaled by 255 and cast to ``uint8``.

        Args:
            image_tensor: Image tensor; it is detached and moved to CPU first.

        Returns:
            The PIL image built from the resulting ``uint8`` array.
        """
        image = image_tensor.detach().cpu()
        if image.ndim == 3 and image.shape[0] in (1, 3):
            image = image.permute(1, 2, 0)
        if image.ndim == 3 and image.shape[-1] == 1:
            # ``Image.fromarray`` rejects (H, W, 1) arrays, so single-channel
            # images are passed as plain (H, W) grayscale.
            image = image.squeeze(-1)
        image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
        return Image.fromarray(image)
|
|
|
|
|
|
class _FakeVideoEncoder(nn.Module):
    """Stand-in video encoder whose features are per-frame pixel means.

    It exposes a ``config`` namespace with ``hidden_size`` and
    ``tubelet_size``, a ``device`` property and ``get_vision_features``.
    """

    def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
        super().__init__()
        # Scalar parameter so the module has a device to report.
        self.weight = nn.Parameter(torch.ones(1))
        self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)

    @property
    def device(self) -> torch.device:
        """Device on which the module's parameter lives."""
        return self.weight.device

    def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
        """Return each frame's mean value repeated across the hidden dimension.

        Args:
            pixel_values_videos: 5-D tensor; the first two dimensions are kept
                and dimensions 2, 3 and 4 are averaged away.

        Returns:
            Tensor of shape ``(dim0, dim1, hidden_size)`` (an expanded view).
        """
        n_videos, n_frames = pixel_values_videos.shape[:2]
        width = self.config.hidden_size
        per_frame_mean = pixel_values_videos.float().mean(dim=(2, 3, 4))
        return per_frame_mean.unsqueeze(-1).expand(n_videos, n_frames, width)
|
|
|
|
|
|
class _FakeVideoProcessor:
    """Stand-in video processor that only adds a leading batch dimension."""

    def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
        """Wrap ``videos`` as a tensor under the ``"pixel_values_videos"`` key.

        Args:
            videos: Array-like video data, converted with ``torch.as_tensor``.
            return_tensors: Must be ``"pt"``; anything else fails the assert.

        Returns:
            ``{"pixel_values_videos": tensor}`` with one extra leading dimension.
        """
        assert return_tensors == "pt"
        batched = torch.as_tensor(videos)[None]
        return {"pixel_values_videos": batched}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap the external models used by ``modeling_vla_jepa`` for local fakes.

    ``Qwen3VLInterface`` is replaced by ``_FakeQwenInterface``, and the
    ``from_pretrained`` attributes of the module's ``AutoModel`` and
    ``AutoVideoProcessor`` are replaced by factories returning
    ``_FakeVideoEncoder`` and ``_FakeVideoProcessor`` instances. All patches
    go through ``monkeypatch``.
    """
    from lerobot.policies.vla_jepa import modeling_vla_jepa as target

    def _make_fake_encoder(*args: object, **kwargs: object) -> _FakeVideoEncoder:
        # Arguments (checkpoint name, options) are accepted and ignored.
        return _FakeVideoEncoder()

    def _make_fake_processor(*args: object, **kwargs: object) -> _FakeVideoProcessor:
        # Arguments (checkpoint name, options) are accepted and ignored.
        return _FakeVideoProcessor()

    monkeypatch.setattr(target, "Qwen3VLInterface", _FakeQwenInterface)
    monkeypatch.setattr(target.AutoModel, "from_pretrained", _make_fake_encoder)
    monkeypatch.setattr(target.AutoVideoProcessor, "from_pretrained", _make_fake_processor)
|