mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
321 lines
10 KiB
Python
321 lines
10 KiB
Python
#!/usr/bin/env python
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from copy import deepcopy
|
|
from types import SimpleNamespace
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from PIL import Image
|
|
from torch import Tensor, nn
|
|
|
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
|
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
|
|
|
|
|
# Module-level pytest mark: ignore the UserWarning whose message starts with
# "In CPU autocast, but the target dtype is not supported".
pytestmark = pytest.mark.filterwarnings(
    "ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
)
|
|
|
|
|
|
# Default Hugging Face location used by the pretrained-checkpoint test.
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO"

# Sizes of the synthetic batches built by the helpers in this module.
BATCH_SIZE = 2
IMAGE_SIZE = 8
STATE_DIM = 4
ACTION_DIM = 3
NUM_VIDEO_FRAMES = 3

# Action chunking settings passed to the policy configuration.
ACTION_HORIZON = 4
N_ACTION_STEPS = 2

# Shapes the tests compare policy outputs against.
EXPECTED_ACTION_CHUNK_SHAPE = (
    BATCH_SIZE,
    ACTION_HORIZON,
    ACTION_DIM,
)
EXPECTED_SELECT_ACTION_SHAPE = (
    BATCH_SIZE,
    ACTION_DIM,
)
|
|
|
|
|
|
def set_seed_all(seed: int) -> None:
    """Seed every random number generator the tests may draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Value handed to each generator.
    """
    # Function-scope import keeps this helper independent of the module's
    # top-level import block.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every device, so a separate call for the
        # current device is unnecessary.
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
class _FakeQwenBackbone(nn.Module):
|
|
def __init__(self, hidden_size: int) -> None:
|
|
super().__init__()
|
|
self.weight = nn.Parameter(torch.ones(1))
|
|
self.config = SimpleNamespace(
|
|
hidden_size=hidden_size,
|
|
text_config=SimpleNamespace(hidden_size=hidden_size),
|
|
)
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return self.weight.device
|
|
|
|
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
|
batch_size, seq_len = input_ids.shape
|
|
hidden_size = self.config.hidden_size
|
|
values = torch.arange(
|
|
batch_size * seq_len * hidden_size,
|
|
device=input_ids.device,
|
|
dtype=torch.float32,
|
|
).view(batch_size, seq_len, hidden_size)
|
|
hidden = values / values.numel() + self.weight
|
|
return SimpleNamespace(hidden_states=[hidden])
|
|
|
|
|
|
class _FakeQwenInterface(nn.Module):
|
|
def __init__(self, config: VLAJEPAConfig) -> None:
|
|
super().__init__()
|
|
self.config = config
|
|
self.model = _FakeQwenBackbone(hidden_size=16)
|
|
|
|
@staticmethod
|
|
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
|
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
|
|
|
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
|
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
|
action_tokens = [
|
|
self.config.special_action_token.format(idx)
|
|
for idx in range(max_action_tokens)
|
|
]
|
|
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
|
return action_tokens, action_token_ids, 2000
|
|
|
|
def build_inputs(
|
|
self,
|
|
images: list[list[Image.Image]],
|
|
instructions: list[str],
|
|
action_prompt: str,
|
|
embodied_prompt: str,
|
|
) -> dict[str, Tensor]:
|
|
batch_size = len(images)
|
|
del images, instructions, action_prompt, embodied_prompt
|
|
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
|
token_ids = (
|
|
[10]
|
|
+ list(range(1000, 1000 + action_count))
|
|
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
|
+ [11]
|
|
)
|
|
input_ids = torch.tensor(
|
|
[token_ids] * batch_size,
|
|
device=self.model.device,
|
|
dtype=torch.long,
|
|
)
|
|
return {"input_ids": input_ids}
|
|
|
|
@staticmethod
|
|
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
|
image = image_tensor.detach().cpu()
|
|
if image.ndim == 3 and image.shape[0] in (1, 3):
|
|
image = image.permute(1, 2, 0)
|
|
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
|
return Image.fromarray(image)
|
|
|
|
|
|
class _FakeVideoEncoder(nn.Module):
|
|
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
|
super().__init__()
|
|
self.weight = nn.Parameter(torch.ones(1))
|
|
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return self.weight.device
|
|
|
|
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
|
batch_size, num_frames = pixel_values_videos.shape[:2]
|
|
hidden_size = self.config.hidden_size
|
|
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
|
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
|
|
|
|
|
class _FakeVideoProcessor:
|
|
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
|
|
assert return_tensors == "pt"
|
|
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
|
|
|
|
|
|
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
    """Swap the external models referenced by ``modeling_vla_jepa`` for fakes.

    ``Qwen3VLInterface`` is replaced by ``_FakeQwenInterface``, while
    ``AutoModel.from_pretrained`` and ``AutoVideoProcessor.from_pretrained``
    are replaced by callables that ignore their arguments and return the fake
    video encoder and fake video processor respectively.
    """
    from lerobot.policies.vla_jepa import modeling_vla_jepa

    def _fake_encoder(*args: object, **kwargs: object) -> _FakeVideoEncoder:
        return _FakeVideoEncoder()

    def _fake_processor(*args: object, **kwargs: object) -> _FakeVideoProcessor:
        return _FakeVideoProcessor()

    monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
    monkeypatch.setattr(modeling_vla_jepa.AutoModel, "from_pretrained", _fake_encoder)
    monkeypatch.setattr(modeling_vla_jepa.AutoVideoProcessor, "from_pretrained", _fake_processor)
|
|
|
|
|
|
def make_config() -> VLAJEPAConfig:
    """Build and validate the ``VLAJEPAConfig`` shared by the tests.

    The config declares one ``laptop`` camera image and a state vector as
    inputs and an action vector as output, targets the CPU, and takes its
    sizes from this module's constants.
    """
    camera_key = f"{OBS_IMAGES}.laptop"
    inputs = {
        camera_key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
    }
    outputs = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}

    # Everything except the feature dictionaries, gathered in one place.
    settings = {
        "device": "cpu",
        "chunk_size": ACTION_HORIZON,
        "n_action_steps": N_ACTION_STEPS,
        "future_action_window_size": ACTION_HORIZON - 1,
        "action_dim": ACTION_DIM,
        "state_dim": STATE_DIM,
        "num_video_frames": NUM_VIDEO_FRAMES,
        "num_action_tokens_per_timestep": 2,
        "num_embodied_action_tokens_per_instruction": 3,
        "num_inference_timesteps": 2,
        "action_hidden_size": 16,
        "action_num_layers": 1,
        "action_num_heads": 2,
        "action_attention_head_dim": 8,
        "predictor_depth": 1,
        "predictor_num_heads": 2,
        "predictor_mlp_ratio": 2.0,
    }

    config = VLAJEPAConfig(input_features=inputs, output_features=outputs, **settings)
    config.validate_features()
    return config
|
|
|
|
|
|
def make_train_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
    """Create a random training batch with images, state, actions and a task.

    Tensors are drawn in the order image, state, action so that results are
    reproducible for a given seed.
    """
    frames = torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)
    state = torch.randn(batch_size, 1, STATE_DIM)
    actions = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM)
    tasks = ["pick up the cube" for _ in range(batch_size)]

    batch: dict[str, Tensor | list[str]] = {}
    batch[f"{OBS_IMAGES}.laptop"] = frames
    batch[OBS_STATE] = state
    batch[ACTION] = actions
    batch["task"] = tasks
    return batch
|
|
|
|
|
|
def make_inference_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
    """Create a random inference batch with one image, a state and a task.

    Unlike the training batch, the image carries no frame dimension, the state
    carries no singleton dimension, and no action is included.
    """
    image = torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)
    state = torch.randn(batch_size, STATE_DIM)
    tasks = ["pick up the cube" for _ in range(batch_size)]

    batch: dict[str, Tensor | list[str]] = {}
    batch[f"{OBS_IMAGES}.laptop"] = image
    batch[OBS_STATE] = state
    batch["task"] = tasks
    return batch
|
|
|
|
|
|
def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
    """Check loss, logs, gradients and batch immutability for a training step."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.train()

    batch = make_train_batch()
    snapshot = deepcopy(batch)

    loss, logs = policy.forward(batch)

    # The loss is a finite scalar and the logs carry exactly three entries.
    assert loss.shape == ()
    assert torch.isfinite(loss)
    assert set(logs) == {"action_loss", "wm_loss", "loss"}
    assert logs["action_loss"] > 0
    assert logs["wm_loss"] >= 0

    # At least one trainable action-model parameter must receive a gradient.
    loss.backward()
    trainable = (p for p in policy.model.action_model.parameters() if p.requires_grad)
    assert any(p.grad is not None for p in trainable)

    # The forward pass must leave the caller's batch untouched.
    assert set(batch) == set(snapshot)
    for key, value in batch.items():
        expected = snapshot[key]
        if isinstance(value, Tensor):
            assert torch.equal(value, expected)
        else:
            assert value == expected
|
|
|
|
|
|
@torch.no_grad()
def test_vla_jepa_action_generation_shape(
    patch_vla_jepa_external_models: None,
) -> None:
    """Check shapes, device and finiteness of predicted and selected actions."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()

    action_chunk = policy.predict_action_chunk(batch)

    assert tuple(action_chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
    assert action_chunk.device.type == "cpu"
    assert torch.isfinite(action_chunk).all()

    # Two consecutive calls are made before any of their results are checked.
    selected = [policy.select_action(batch) for _ in range(2)]

    for action in selected:
        assert tuple(action.shape) == EXPECTED_SELECT_ACTION_SHAPE
    for action in selected:
        assert torch.isfinite(action).all()
|
|
|
|
|
|
@torch.no_grad()
def test_vla_jepa_inference_reproducibility(
    patch_vla_jepa_external_models: None,
) -> None:
    """Check that re-seeding yields the same action chunk for the same batch."""
    set_seed_all(42)
    policy = VLAJEPAPolicy(make_config())
    policy.eval()
    batch = make_inference_batch()

    def _predict_with_seed(seed: int) -> Tensor:
        # Reset every generator right before sampling so both runs start alike.
        set_seed_all(seed)
        return policy.predict_action_chunk(batch)

    actions_1 = _predict_with_seed(123)
    actions_2 = _predict_with_seed(123)

    assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
    assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
|
|
|
|
|
def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
    """Load the pretrained checkpoint from the local Hugging Face cache.

    The repo id, subfolder and checkpoint filename can be overridden through
    the ``VLA_JEPA_PRETRAINED_REPO_ID``, ``VLA_JEPA_PRETRAINED_SUBFOLDER`` and
    ``VLA_JEPA_PRETRAINED_CHECKPOINT`` environment variables. The test is
    skipped when the file is not already cached locally.
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.errors import LocalEntryNotFoundError

    env = os.environ
    repo_id = env.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
    subfolder = env.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
    default_checkpoint = f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt"
    checkpoint_filename = env.get("VLA_JEPA_PRETRAINED_CHECKPOINT", default_checkpoint)

    try:
        # local_files_only=True restricts the lookup to the local cache.
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=checkpoint_filename,
            local_files_only=True,
        )
    except LocalEntryNotFoundError:
        pytest.skip(
            f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache."
        )

    try:
        # NOTE(review): weights_only=False permits full unpickling of the
        # checkpoint, which is only safe for trusted files; confirm that the
        # cached repo is trusted.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
    except TypeError:
        # Presumably covers torch versions whose torch.load does not accept
        # the extra keywords (TODO confirm).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Use the first truthy nested entry, falling back to the checkpoint itself.
    state_dict = checkpoint
    for wrapper_key in ("state_dict", "model_state_dict", "model"):
        nested = checkpoint.get(wrapper_key)
        if nested:
            state_dict = nested
            break

    assert isinstance(state_dict, dict)
    assert len(state_dict) > 0
    assert all(isinstance(key, str) for key in list(state_dict)[:10])
|