mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
(feat)policies: add VLA-JEPA
This commit is contained in:
@@ -0,0 +1,320 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||
)
|
||||
|
||||
|
||||
BATCH_SIZE = 2
|
||||
ACTION_DIM = 3
|
||||
STATE_DIM = 4
|
||||
IMAGE_SIZE = 8
|
||||
ACTION_HORIZON = 4
|
||||
N_ACTION_STEPS = 2
|
||||
NUM_VIDEO_FRAMES = 3
|
||||
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||
|
||||
|
||||
def set_seed_all(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
class _FakeQwenBackbone(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
self.config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
text_config=SimpleNamespace(hidden_size=hidden_size),
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = self.config.hidden_size
|
||||
values = torch.arange(
|
||||
batch_size * seq_len * hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=torch.float32,
|
||||
).view(batch_size, seq_len, hidden_size)
|
||||
hidden = values / values.numel() + self.weight
|
||||
return SimpleNamespace(hidden_states=[hidden])
|
||||
|
||||
|
||||
class _FakeQwenInterface(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = _FakeQwenBackbone(hidden_size=16)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
action_tokens = [
|
||||
self.config.special_action_token.format(idx)
|
||||
for idx in range(max_action_tokens)
|
||||
]
|
||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||
return action_tokens, action_token_ids, 2000
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, Tensor]:
|
||||
batch_size = len(images)
|
||||
del images, instructions, action_prompt, embodied_prompt
|
||||
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
token_ids = (
|
||||
[10]
|
||||
+ list(range(1000, 1000 + action_count))
|
||||
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
||||
+ [11]
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[token_ids] * batch_size,
|
||||
device=self.model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
||||
batch_size, num_frames = pixel_values_videos.shape[:2]
|
||||
hidden_size = self.config.hidden_size
|
||||
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
||||
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
||||
|
||||
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoModel,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoEncoder(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoVideoProcessor,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoProcessor(),
|
||||
)
|
||||
|
||||
|
||||
def make_config() -> VLAJEPAConfig:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
||||
},
|
||||
output_features={
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
|
||||
},
|
||||
device="cpu",
|
||||
chunk_size=ACTION_HORIZON,
|
||||
n_action_steps=N_ACTION_STEPS,
|
||||
future_action_window_size=ACTION_HORIZON - 1,
|
||||
action_dim=ACTION_DIM,
|
||||
state_dim=STATE_DIM,
|
||||
num_video_frames=NUM_VIDEO_FRAMES,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=16,
|
||||
action_num_layers=1,
|
||||
action_num_heads=2,
|
||||
action_attention_head_dim=8,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def make_train_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, NUM_VIDEO_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, 1, STATE_DIM),
|
||||
ACTION: torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def make_inference_batch(batch_size: int = BATCH_SIZE) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, STATE_DIM),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
|
||||
batch = make_train_batch()
|
||||
batch_before = deepcopy(batch)
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
assert logs["action_loss"] > 0
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
loss.backward()
|
||||
assert any(
|
||||
param.grad is not None
|
||||
for param in policy.model.action_model.parameters()
|
||||
if param.requires_grad
|
||||
)
|
||||
assert set(batch) == set(batch_before)
|
||||
for key, value in batch.items():
|
||||
if isinstance(value, Tensor):
|
||||
assert torch.equal(value, batch_before[key])
|
||||
else:
|
||||
assert value == batch_before[key]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_vla_jepa_action_generation_shape(
|
||||
patch_vla_jepa_external_models: None,
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
action_chunk = policy.predict_action_chunk(batch)
|
||||
|
||||
assert tuple(action_chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert action_chunk.device.type == "cpu"
|
||||
assert torch.isfinite(action_chunk).all()
|
||||
|
||||
first_action = policy.select_action(batch)
|
||||
second_action = policy.select_action(batch)
|
||||
|
||||
assert tuple(first_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(second_action.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert torch.isfinite(first_action).all()
|
||||
assert torch.isfinite(second_action).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_vla_jepa_inference_reproducibility(
|
||||
patch_vla_jepa_external_models: None,
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
set_seed_all(123)
|
||||
actions_1 = policy.predict_action_chunk(batch)
|
||||
|
||||
set_seed_all(123)
|
||||
actions_2 = policy.predict_action_chunk(batch)
|
||||
|
||||
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||
|
||||
|
||||
def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
|
||||
repo_id = os.environ.get("VLA_JEPA_PRETRAINED_REPO_ID", PRETRAINED_REPO_ID)
|
||||
subfolder = os.environ.get("VLA_JEPA_PRETRAINED_SUBFOLDER", PRETRAINED_SUBFOLDER).strip("/")
|
||||
checkpoint_filename = os.environ.get(
|
||||
"VLA_JEPA_PRETRAINED_CHECKPOINT",
|
||||
f"{subfolder}/checkpoints/VLA-JEPA-{subfolder}.pt",
|
||||
)
|
||||
|
||||
try:
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=checkpoint_filename,
|
||||
local_files_only=True,
|
||||
)
|
||||
except LocalEntryNotFoundError:
|
||||
pytest.skip(
|
||||
f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache."
|
||||
)
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
state_dict = (
|
||||
checkpoint.get("state_dict")
|
||||
or checkpoint.get("model_state_dict")
|
||||
or checkpoint.get("model")
|
||||
or checkpoint
|
||||
)
|
||||
|
||||
assert isinstance(state_dict, dict)
|
||||
assert len(state_dict) > 0
|
||||
assert all(isinstance(key, str) for key in list(state_dict)[:10])
|
||||
Reference in New Issue
Block a user