mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00

support vla_jepa

This commit is contained in:
committed by Maxime Ellerbach
parent e93fd2bcfe
commit 2757266f6b
@@ -56,6 +56,7 @@ from .sac.configuration_sac import SACConfig
 from .smolvla.configuration_smolvla import SmolVLAConfig
 from .tdmpc.configuration_tdmpc import TDMPCConfig
 from .utils import validate_visual_features_consistency
+from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
 from .vqbet.configuration_vqbet import VQBeTConfig
 from .wall_x.configuration_wall_x import WallXConfig
 from .xvla.configuration_xvla import XVLAConfig
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from .eo1.modeling_eo1 import EO1Policy

         return EO1Policy
+    elif name == "vla_jepa":
+        from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
+
+        return VLAJEPAPolicy
     else:
         try:
             return _get_policy_cls_from_policy_name(name=name)
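A quick usage sketch of the resolver; the factory's import path is an assumption, not shown in this diff:

    # Hypothetical usage; module path assumed.
    from lerobot.policies.factory import get_policy_class

    policy_cls = get_policy_class("vla_jepa")  # resolves via the new branch
    assert policy_cls.__name__ == "VLAJEPAPolicy"

The lazy in-branch import keeps the heavy model code out of module import time, matching the pattern of the existing eo1 branch.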
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return WallXConfig(**kwargs)
     elif policy_type == "eo1":
        return EO1Config(**kwargs)
+    elif policy_type == "vla_jepa":
+        return VLAJEPAConfig(**kwargs)
     else:
         try:
             config_cls = PreTrainedConfig.get_choice_class(policy_type)
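On the config side, kwargs pass straight through to VLAJEPAConfig. Continuing the sketch above (field names come from the configuration hunk later in this commit):

    from lerobot.policies.factory import make_policy_config  # path assumed as above

    cfg = make_policy_config("vla_jepa", state_dim=8, action_num_heads=16)
    assert type(cfg).__name__ == "VLAJEPAConfig"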
@@ -406,6 +413,7 @@ def make_pre_post_processors(
             config=policy_cfg,
             dataset_stats=kwargs.get("dataset_stats"),
         )

     elif isinstance(policy_cfg, EO1Config):
         from .eo1.processor_eo1 import make_eo1_pre_post_processors
+
@@ -414,6 +422,14 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )

+    elif isinstance(policy_cfg, VLAJEPAConfig):
+        from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
+
+        processors = make_vla_jepa_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )
+
     else:
         try:
             processors = _make_processors_from_policy_config(
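The processor branch completes the registration surface: each policy family contributes a make_<name>_pre_post_processors helper with the same keyword interface. A hedged sketch of the call (that the helper returns a (preprocessor, postprocessor) pair is an inference from the plural name, not confirmed by this diff):

    # Sketch only; assumes the helper is importable like its eo1 counterpart.
    processors = make_vla_jepa_pre_post_processors(
        config=cfg,          # the VLAJEPAConfig built above
        dataset_stats=None,  # optional normalization statistics
    )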
@@ -1,7 +1,6 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -203,7 +202,9 @@ class VLAJEPAActionHead(nn.Module):
             else None
         )
         self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size)
-        self.position_embedding = nn.Embedding(config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size)
+        self.position_embedding = nn.Embedding(
+            config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size
+        )
         self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)

     def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
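The Beta(alpha, beta) distribution constructed here feeds sample_time. Its body is not shown in this diff, but flow-matching heads typically draw the interpolation time t from such a prior rather than uniformly; a self-contained sketch under that assumption (the concentration values stand in for config.action_noise_beta_alpha and config.action_noise_beta_beta):

    import torch
    from torch.distributions import Beta

    beta_dist = Beta(1.5, 1.0)  # assumed hyperparameters

    def sample_time(batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        # One flow-matching time t in (0, 1) per batch element; a Beta prior
        # lets training oversample whichever noise levels help most.
        return beta_dist.sample((batch_size,)).to(device=device, dtype=dtype)

    t = sample_time(4, torch.device("cpu"), torch.float32)
    print(t.shape)  # torch.Size([4])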
@@ -268,7 +269,9 @@ class VLAJEPAActionHead(nn.Module):
         for step in range(self.num_inference_timesteps):
             t_cont = step / float(max(self.num_inference_timesteps, 1))
             t_value = int(t_cont * self.config.action_num_timestep_buckets)
-            timesteps = torch.full((batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long)
+            timesteps = torch.full(
+                (batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
+            )
             hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
             pred = self.model(
                 hidden_states=hidden_states,
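The reformatted call sits inside the Euler-style sampling loop; the bucketing maps continuous time in [0, 1) onto integer timestep ids the DiT can embed. A worked trace of the arithmetic, using the default bucket count from the config below:

    num_inference_timesteps = 4
    action_num_timestep_buckets = 1000

    for step in range(num_inference_timesteps):
        t_cont = step / float(max(num_inference_timesteps, 1))
        t_value = int(t_cont * action_num_timestep_buckets)
        print(step, t_cont, t_value)
    # 0 0.0  0
    # 1 0.25 250
    # 2 0.5  500
    # 3 0.75 750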
@@ -27,12 +27,14 @@ class VLAJEPAConfig(PreTrainedConfig):
     jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"

     tokenizer_padding_side: str = "left"
-    prompt_template: str = "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
+    prompt_template: str = (
+        "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
+    )
     special_action_token: str = "<|action_{}|>"
     embodied_action_token: str = "<|embodied_action|>"

     action_dim: int = 7
-    state_dim: int = 7
+    state_dim: int = 8
     future_action_window_size: int = 15
     past_action_window_size: int = 0
     num_action_tokens_per_timestep: int = 4

@@ -42,7 +44,7 @@ class VLAJEPAConfig(PreTrainedConfig):
     action_hidden_size: int = 1024
     action_model_type: str = "DiT-B"
     action_num_layers: int = 12
-    action_num_heads: int = 12
+    action_num_heads: int = 16
     action_attention_head_dim: int = 64
     action_dropout: float = 0.1
     action_num_timestep_buckets: int = 1000
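The head-count bump is not cosmetic: with 64-dim heads, 16 heads exactly tile the 1024-wide action hidden state, while the old value of 12 did not. A one-line sanity check:

    action_num_heads = 16
    action_attention_head_dim = 64
    action_hidden_size = 1024

    # 16 * 64 == 1024; the old 12 heads gave 768, which does not match.
    assert action_num_heads * action_attention_head_dim == action_hidden_size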
@@ -3,29 +3,60 @@ from __future__ import annotations
 from collections import deque
 from pathlib import Path

 import numpy as np
 import torch
 import torch.nn.functional as F
 from PIL import Image
 from torch import Tensor, nn
 from transformers import AutoModel, AutoVideoProcessor

 from lerobot.policies.pretrained import PreTrainedPolicy, T
 from lerobot.policies.utils import populate_queues
-from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import ACTION, OBS_STATE

 from .action_head import VLAJEPAActionHead
 from .configuration_vla_jepa import VLAJEPAConfig
 from .qwen_interface import Qwen3VLInterface
 from .world_model import ActionConditionedVideoPredictor

+# ============================================================================
+# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
+# ============================================================================


 class VLAJEPAModel(nn.Module):
+    """
+    Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
+
+    Components:
+    - Qwen3-VL: vision-language backbone for fused embeddings
+    - DiT-B: flow-matching action head for future action prediction
+    - V-JEPA: world model for video frame prediction
+
+    Input: List[dict] native format (same as original starVLA)
+    - "image": List[PIL.Image] (multi-view images)
+    - "video": np.ndarray [V, T, H, W, 3]
+    - "lang": str (task instruction)
+    - "action": np.ndarray [T, action_dim] (optional, training only)
+    - "state": np.ndarray [1, state_dim] (optional)
+    """

     def __init__(self, config: VLAJEPAConfig) -> None:
         super().__init__()
         self.config = config

         # Vision-language backbone
         self.qwen = Qwen3VLInterface(config)
-        self.action_tokens, self.action_token_ids, self.embodied_action_token_id = self.qwen.expand_tokenizer()

+        # Tokenizer expansion for special action tokens
+        self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
+            self.qwen.expand_tokenizer()
+        )

         # Action head (flow-matching DiT)
         self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)

         # JEPA world model components
         self.video_encoder = AutoModel.from_pretrained(
             config.jepa_encoder_name,
             torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
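To make the native input contract concrete, here is a minimal example dict matching the class docstring above; the two-view, eight-frame, 16-step shapes are illustrative choices, not values fixed by this diff:

    import numpy as np
    from PIL import Image

    example = {
        "image": [Image.new("RGB", (256, 256)) for _ in range(2)],  # 2 camera views
        "video": np.zeros((2, 8, 256, 256, 3), dtype=np.uint8),     # [V, T, H, W, 3]
        "lang": "pick up the red block",
        "action": np.zeros((16, 7), dtype=np.float32),              # [T, action_dim]
        "state": np.zeros((1, 8), dtype=np.float32),                # [1, state_dim]
    }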
@@ -40,144 +71,224 @@ class VLAJEPAModel(nn.Module):
             mlp_ratio=config.predictor_mlp_ratio,
             num_action_tokens_per_step=config.num_action_tokens_per_timestep,
         )

         # Build prompt placeholders (same as original)
         self.replace_prompt = "".join(
             token * self.config.num_action_tokens_per_timestep
             for token in self.action_tokens[: self.config.num_video_frames - 1]
         )
-        self.embodied_replace_prompt = self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
+        self.embodied_replace_prompt = (
+            self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
+        )

-    def _collect_images(self, batch: dict[str, Tensor]) -> list[list]:
-        sample_key = self.config.image_features[0]
-        batch_size = batch[sample_key].shape[0]
-        images = [[] for _ in range(batch_size)]
-        for key in self.config.image_features:
-            tensor = batch[key]
-            if tensor.ndim == 5:
-                tensor = tensor[:, -1]
-            for idx in range(batch_size):
-                images[idx].append(self.qwen.tensor_to_pil(tensor[idx]))
-        return images
+    # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----

-    def _collect_videos(self, batch: dict[str, Tensor]) -> torch.Tensor:
-        first_key = self.config.image_features[0]
-        source = batch[first_key]
-        if source.ndim == 4:
-            source = source.unsqueeze(1).repeat(1, self.config.num_video_frames, 1, 1, 1)
-        elif source.ndim == 5 and source.shape[1] < self.config.num_video_frames:
-            pad = source[:, -1:].repeat(1, self.config.num_video_frames - source.shape[1], 1, 1, 1)
-            source = torch.cat([source, pad], dim=1)
-        elif source.ndim == 5:
-            source = source[:, -self.config.num_video_frames :]
-        else:
-            raise ValueError(f"Unsupported image tensor shape for JEPA: {tuple(source.shape)}")
-        return source
+    def forward(self, examples: list[dict]) -> dict[str, Tensor]:
+        """
+        Native forward pass following original starVLA VLA_JEPA.forward.

-    def _get_tasks(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
-        tasks = batch.get("task")
-        if tasks is None:
-            return ["Execute the robot action."] * next(iter(batch.values())).shape[0]
-        if isinstance(tasks, str):
-            return [tasks]
-        return list(tasks)
+        Args:
+            examples: List of per-sample dicts with keys:
+                "image" : List[PIL.Image] — multi-view images
+                "video" : np.ndarray [V, T, H, W, 3]
+                "lang" : str — task instruction
+                "action" : np.ndarray [T, action_dim] (optional)
+                "state" : np.ndarray [1, state_dim] (optional)

-    def _extract_qwen_conditioning(self, batch: dict[str, Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        images = self._collect_images(batch)
-        tasks = self._get_tasks(batch)
+        Returns:
+            dict with "action_loss" and "wm_loss" keys (scalar Tensors).
+        """
+        # Unpack native format (same pattern as original VLA_JEPA.py)
+        batch_images = [ex["image"] for ex in examples]  # List[List[PIL.Image]]
+        batch_videos = [ex["video"] for ex in examples]  # List[np.ndarray]
+        instructions = [ex["lang"] for ex in examples]  # List[str]
+        has_action = "action" in examples[0] and examples[0]["action"] is not None
+        actions = [ex["action"] for ex in examples] if has_action else None
+        has_state = "state" in examples[0] and examples[0]["state"] is not None
+        state = [ex["state"] for ex in examples] if has_state else None

+        # Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
+        batch_videos = np.stack(batch_videos)
+        batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4)  # [B, V, T, 3, H, W]

+        # ---- Step 1: QwenVL encode (same as original) ----
         qwen_inputs = self.qwen.build_inputs(
-            images=images,
-            instructions=tasks,
+            images=batch_images,
+            instructions=instructions,
             action_prompt=self.replace_prompt,
             embodied_prompt=self.embodied_replace_prompt,
         )
-        outputs = self.qwen.model(
-            **qwen_inputs,
-            output_hidden_states=True,
-            output_attentions=False,
-            return_dict=True,
-        )
-        hidden = outputs.hidden_states[-1]

         # Locate action and embodied-action tokens in the tokenized sequence
         action_mask = torch.isin(
             qwen_inputs["input_ids"],
             torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
         )
         action_indices = action_mask.nonzero(as_tuple=True)
-        action_tokens = hidden[action_indices[0], action_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])

         embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
         embodied_indices = embodied_mask.nonzero(as_tuple=True)
-        embodied_tokens = hidden[embodied_indices[0], embodied_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
-        return action_tokens, embodied_tokens

-    def _prepare_state(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
-        if OBS_STATE not in batch:
-            return None
-        state = batch[OBS_STATE]
-        if state.ndim > 2:
-            state = state[:, -1, :]
-        return state.to(device=device, dtype=dtype)
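The mask-and-gather idiom above (reused below for the embodied tokens) is worth seeing in isolation. A toy demo with made-up ids; note that the final view() only works because every sample in the batch contains the same number of special tokens, which the fixed prompt template guarantees:

    import torch

    # Toy vocabulary: ids 100-103 are action tokens; positions 1 and 2 match.
    action_token_ids = [100, 101, 102, 103]
    input_ids = torch.tensor([[5, 100, 101, 7, 200, 200]])
    hidden = torch.randn(1, 6, 16)  # [B, seq_len, H] stand-in for the LM hidden state

    mask = torch.isin(input_ids, torch.tensor(action_token_ids))
    idx = mask.nonzero(as_tuple=True)
    action_tokens = hidden[idx[0], idx[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
    print(action_tokens.shape)  # torch.Size([1, 2, 16]), one row per matched position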
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
qwen_outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
|
||||
B, _, H = last_hidden.shape
|
||||
|
||||
def _prepare_action_targets(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
actions = batch[ACTION]
|
||||
if actions.ndim == 2:
|
||||
actions = actions.unsqueeze(1)
|
||||
horizon = self.config.future_action_window_size + 1
|
||||
if actions.shape[1] < horizon:
|
||||
pad = actions[:, -1:].repeat(1, horizon - actions.shape[1], 1)
|
||||
actions = torch.cat([actions, pad], dim=1)
|
||||
return actions[:, -horizon:].to(device=device, dtype=dtype)
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(B, -1, H)
|
||||
|
||||
def _encode_video(self, video_tensor: torch.Tensor) -> torch.Tensor:
|
||||
processed = []
|
||||
for sample in video_tensor:
|
||||
processed_sample = self.video_processor(videos=sample, return_tensors="pt")["pixel_values_videos"]
|
||||
processed.append(processed_sample)
|
||||
pixel_values = torch.cat(processed, dim=0).to(self.video_encoder.device)
|
||||
return self.video_encoder.get_vision_features(pixel_values_videos=pixel_values)
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
|
||||
|
||||
def _compute_world_model_loss(self, batch: dict[str, Tensor], action_tokens: torch.Tensor) -> torch.Tensor | None:
|
||||
if not self.config.enable_world_model:
|
||||
return None
|
||||
video_tensor = self._collect_videos(batch)
|
||||
video_features = self._encode_video(video_tensor)
|
||||
batch_size = video_tensor.shape[0]
|
||||
num_frames = video_tensor.shape[1]
|
||||
tokens_per_frame = video_features.shape[1] // num_frames
|
||||
video_features = video_features.view(batch_size, num_frames, tokens_per_frame, -1)
|
||||
input_states = video_features[:, :-1]
|
||||
gt_states = video_features[:, 1:]
|
||||
# ---- Step 2: JEPA Encoder (same as original) ----
|
||||
B, V, T_frames, C, H_img, W_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(B * V, T_frames, C, H_img, W_img)
|
||||
|
||||
expected_tokens = (num_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_tokens:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_tokens - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
action_tokens = action_tokens[:, :expected_tokens]
|
||||
action_tokens = action_tokens.view(batch_size, num_frames - 1, self.config.num_action_tokens_per_timestep, -1)
|
||||
pred_states = self.video_predictor(input_states, action_tokens)
|
||||
return F.l1_loss(pred_states, gt_states, reduction="mean")
|
||||
video_pixels = []
|
||||
for i in range(B * V):
|
||||
video_pixels.append(
|
||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device)
|
||||
)
|
||||
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
||||
|
||||
-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
-        action_tokens, embodied_tokens = self._extract_qwen_conditioning(batch)
-        state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
-        target_actions = self._prepare_action_targets(batch, embodied_tokens.device, embodied_tokens.dtype)
-        action_loss = self.action_model(embodied_tokens, target_actions, state)
+        with torch.no_grad():
+            video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
+        # Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
+        video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=V, dim=0), dim=2)

-        wm_loss = self._compute_world_model_loss(batch, action_tokens)
-        total_loss = action_loss
-        logs = {"action_loss": action_loss.detach()}
-        if wm_loss is not None:
-            total_loss = total_loss + self.config.world_model_loss_weight * wm_loss
-            logs["wm_loss"] = wm_loss.detach()
-        logs["loss"] = total_loss.detach()
-        return total_loss, logs
+        # ---- Step 3: JEPA Predictor (same as original) ----
+        tubelet_size = self.video_encoder.config.tubelet_size
+        T_enc = T_frames // tubelet_size
+        device_wm = video_embeddings.device

+        if T_enc < 2:
+            # Not enough frames for JEPA prediction (need at least 2 encoded frames)
+            wm_loss = torch.tensor(0.0, device=device_wm)
+        else:
+            tokens_per_frame = video_embeddings.shape[1] // T_enc

+            # input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
+            # gt_states:    frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
+            input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :]
+            gt_states = video_embeddings[:, tokens_per_frame:, :]
+            D_emb = input_states.shape[-1]

+            # Reshape to 4D for ActionConditionedVideoPredictor:
+            # [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
+            input_states_4d = input_states.view(B, T_enc - 1, tokens_per_frame, D_emb)

+            # Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
+            expected_actions = (T_enc - 1) * self.config.num_action_tokens_per_timestep
+            if action_tokens.shape[1] < expected_actions:
+                pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
+                action_tokens = torch.cat([action_tokens, pad], dim=1)
+            act_4d = action_tokens[:, :expected_actions].view(
+                B, T_enc - 1, self.config.num_action_tokens_per_timestep, -1
+            )

+            # Cast to float32 for predictor (Linear layers are float32)
+            pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
+            predicted_states = pred_4d.reshape(B, -1, D_emb)

+            wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")

+        if not has_action:
+            return {"wm_loss": wm_loss}

+        # ---- Step 4: Action Head (same as original) ----
+        with torch.autocast("cuda", dtype=torch.float32):
+            actions_tensor = torch.tensor(
+                np.array(actions), device=last_hidden.device, dtype=torch.float32
+            )  # [B, T_full, action_dim]
+            action_horizon = self.config.future_action_window_size + 1
+            actions_target = actions_tensor[:, -action_horizon:, :]

+            state_tensor = None
+            if state is not None:
+                state_tensor = torch.tensor(
+                    np.array(state), device=last_hidden.device, dtype=torch.float32
+                )  # [B, 1, state_dim]

+            # Cast embodied tokens to float32 for action model compatibility
+            action_loss = self.action_model(embodied_action_tokens.float(), actions_target, state_tensor)

+        return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}

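The slicing in Step 3 is just a one-frame teacher-forcing shift over V-JEPA's tubelet-encoded token sequence. A toy shape trace with made-up dimensions:

    import torch

    # 8 raw frames, tubelet_size 2 -> 4 encoded "frames" of 6 tokens each.
    B, T_frames, tubelet_size, tokens_per_frame, D = 2, 8, 2, 6, 32
    T_enc = T_frames // tubelet_size  # 4

    video_embeddings = torch.randn(B, T_enc * tokens_per_frame, D)

    # Predict frame k+1's tokens from frame k's tokens.
    input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :]  # frames 0..2
    gt_states = video_embeddings[:, tokens_per_frame:, :]                    # frames 1..3
    print(input_states.shape, gt_states.shape)  # both torch.Size([2, 18, 32])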
+    # ---- Native predict_action (follows original VLA_JEPA.predict_action) ----

     @torch.no_grad()
-    def predict_action(self, batch: dict[str, Tensor]) -> Tensor:
-        _, embodied_tokens = self._extract_qwen_conditioning(batch)
-        state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
-        return self.action_model.predict_action(embodied_tokens, state)
+    def predict_action(
+        self,
+        batch_images: list[list[Image.Image]],
+        instructions: list[str],
+        state: np.ndarray | None = None,
+    ) -> np.ndarray:
+        """
+        Native action prediction following original VLA_JEPA.predict_action.
+
+        Args:
+            batch_images: List of samples; each is List[PIL.Image] (multi-view).
+            instructions: Task instructions, one per sample.
+            state: Optional [B, state_dim] numpy array.
+
+        Returns:
+            np.ndarray [B, action_horizon, action_dim] — predicted actions.
+        """
+        qwen_inputs = self.qwen.build_inputs(
+            images=batch_images,
+            instructions=instructions,
+            action_prompt=self.replace_prompt,
+            embodied_prompt=self.embodied_replace_prompt,
+        )
+
+        embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
+        embodied_indices = embodied_mask.nonzero(as_tuple=True)
+
+        with torch.autocast("cuda", dtype=torch.bfloat16):
+            qwen_outputs = self.qwen.model(
+                **qwen_inputs,
+                output_hidden_states=True,
+                output_attentions=False,
+                return_dict=True,
+            )
+        last_hidden = qwen_outputs.hidden_states[-1]
+        B, _, H = last_hidden.shape
+        embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
+
+        state_tensor = None
+        if state is not None:
+            state_tensor = torch.from_numpy(np.array(state)).to(
+                device=last_hidden.device, dtype=torch.float32
+            )
+
+        with torch.autocast("cuda", dtype=torch.float32):
+            # Cast embodied tokens to float32 for action model compatibility
+            pred_actions = self.action_model.predict_action(
+                embodied_action_tokens.float(), state_tensor
+            )  # [B, action_horizon, action_dim]
+
+        return pred_actions.detach().cpu().numpy()


+# ============================================================================
+# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
+# ============================================================================


 class VLAJEPAPolicy(PreTrainedPolicy):
+    """
+    LeRobot adapter for VLA-JEPA.
+
+    Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
+    VLA-JEPA format (List[dict]), calls the native model, and converts outputs
+    back to LeRobot format.
+    """

     config_class = VLAJEPAConfig
     name = "vla_jepa"

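With the defaults from the config hunk (future_action_window_size=15, action_dim=7, state_dim=8), the native inference call would look roughly like this; the model construction line is assumed, since weights and the config constructor are not part of this diff:

    import numpy as np
    from PIL import Image

    imgs = [[Image.new("RGB", (256, 256))]]     # one sample, one view
    state = np.zeros((1, 8), dtype=np.float32)  # [B, state_dim]

    # model = VLAJEPAModel(VLAJEPAConfig())     # assumed; needs downloaded weights
    # actions = model.predict_action(imgs, ["pick up the cube"], state)
    # actions.shape -> (1, 16, 7), i.e. [B, future_action_window_size + 1, action_dim]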
@@ -190,22 +301,193 @@ class VLAJEPAPolicy(PreTrainedPolicy):
     def reset(self) -> None:
         self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}

+    # ---- Format Conversion: LeRobot → Native ----
+
+    def _lerobot_to_native(self, batch: dict[str, Tensor]) -> list[dict]:
+        """
+        Convert LeRobot batch format to native VLA-JEPA examples format.
+
+        LeRobot format:
+            batch = {
+                "observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
+                "observation.state": Tensor [B, state_dim] or [B, T, state_dim],
+                "action": Tensor [B, chunk_size, action_dim],  (training only)
+                "task": str | List[str],  (optional instruction)
+            }
+
+        Native format (List[dict]):
+            {
+                "image": List[PIL.Image],               # multi-view images per sample
+                "video": np.ndarray [V, T, H, W, 3],
+                "lang": str,                            # task instruction
+                "action": np.ndarray [T, action_dim],   # optional
+                "state": np.ndarray [1, state_dim],     # optional
+            }
+        """
+        # Determine batch size from the first image feature
+        image_keys = list(self.config.image_features.keys())
+        if not image_keys:
+            raise ValueError("VLAJEPA requires at least one image feature.")
+        first_key = image_keys[0]
+        first_tensor = batch[first_key]
+        batch_size = first_tensor.shape[0]
+
+        # ---- Collect images per sample ----
+        # images_per_sample[b][v] = PIL.Image for view v
+        images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
+        for key in image_keys:
+            tensor = batch[key]  # [B, C, H, W] or [B, T, C, H, W]
+            if tensor.ndim == 5:
+                # Multi-frame: take the last frame as the "current" image
+                tensor = tensor[:, -1]
+            for b in range(batch_size):
+                images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
+
+        # ---- Collect videos per sample ----
+        # Build video arrays: for each sample, stack views as [V, T, H, W, 3]
+        num_views = len(image_keys)
+        # Check whether any image feature has a time dimension
+        has_video = any(batch[k].ndim == 5 for k in image_keys if k in batch)
+        video_source = None
+        for k in image_keys:
+            if k in batch:
+                video_source = batch[k]  # Use first available for shape inspection
+                break
+
+        if video_source is None:
+            raise ValueError("No image data found in batch for video construction.")
+
+        videos_per_sample = []
+        for b in range(batch_size):
+            sample_views = []
+            for k in image_keys:
+                t = batch[k][b]  # [C, H, W] or [T, C, H, W]
+                if t.ndim == 3:
+                    t = t.unsqueeze(0)  # [1, C, H, W]
+                # Convert to [T, H, W, 3] numpy
+                t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
+                # Rescale to [0, 255] if needed, then clamp
+                if t_np.max() <= 1.0:
+                    t_np = t_np * 255.0
+                t_np = t_np.clip(0, 255).astype(np.uint8)
+                sample_views.append(t_np)
+            # Stack views: [V, T, H, W, 3]
+            videos_per_sample.append(np.stack(sample_views, axis=0))
+
+        # ---- Collect instructions ----
+        tasks = batch.get("task")
+        if tasks is None:
+            instructions = ["Execute the robot action."] * batch_size
+        elif isinstance(tasks, str):
+            instructions = [tasks] * batch_size
+        else:
+            instructions = list(tasks)
+
+        # ---- Collect actions (training only) ----
+        actions_list = None
+        if ACTION in batch:
+            actions_tensor = batch[ACTION]  # [B, chunk_size, action_dim]
+            if actions_tensor.ndim == 2:
+                actions_tensor = actions_tensor.unsqueeze(1)
+            actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
+
+        # ---- Collect state ----
+        state_list = None
+        if OBS_STATE in batch:
+            state_tensor = batch[OBS_STATE]  # [B, state_dim] or [B, T, state_dim]
+            if state_tensor.ndim > 2:
+                state_tensor = state_tensor[:, -1, :]
+            if state_tensor.ndim == 2:
+                state_tensor = state_tensor.unsqueeze(1)  # [B, 1, state_dim]
+            state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
+
+        # ---- Assemble native examples ----
+        examples = []
+        for b in range(batch_size):
+            example = {
+                "image": images_per_sample[b],
+                "video": videos_per_sample[b],
+                "lang": instructions[b],
+            }
+            if actions_list is not None:
+                example["action"] = actions_list[b]
+            if state_list is not None:
+                example["state"] = state_list[b]
+            examples.append(example)

+        return examples
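A toy batch makes the conversion contract easy to check; the key names follow the docstring above, while the constructed policy instance itself is assumed:

    import torch

    batch = {
        "observation.images.top": torch.rand(2, 3, 224, 224),  # [B, C, H, W]
        "observation.state": torch.rand(2, 8),                 # [B, state_dim]
        "action": torch.rand(2, 16, 7),                        # [B, chunk_size, action_dim]
        "task": ["stack the blocks", "open the drawer"],
    }
    # examples = policy._lerobot_to_native(batch)  # assumes a built VLAJEPAPolicy
    # len(examples) == 2
    # examples[0]["video"].shape == (1, 1, 224, 224, 3): one view, one frame per sample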

+    # ---- Format Conversion: Native → LeRobot ----
+
+    def _native_to_lerobot(self, native_output: dict[str, Tensor]) -> tuple[Tensor, dict[str, float]]:
+        """
+        Convert native VLA-JEPA output dict to LeRobot (loss, logs) format.
+
+        Native output:
+            {"action_loss": Tensor, "wm_loss": Tensor}
+            or {"wm_loss": Tensor} (video-only mode)
+
+        LeRobot output:
+            (total_loss: scalar Tensor, {"action_loss": float, "wm_loss": float, "loss": float})
+        """
+        logs: dict[str, float] = {}
+        total_loss = torch.tensor(0.0, device=self.config.device)
+
+        if "action_loss" in native_output:
+            total_loss = total_loss + native_output["action_loss"]
+            logs["action_loss"] = native_output["action_loss"].detach().item()
+
+        if "wm_loss" in native_output:
+            wm_loss = native_output["wm_loss"]
+            logs["wm_loss"] = wm_loss.detach().item()
+
+        logs["loss"] = (
+            total_loss.detach().item()
+            if total_loss.item() != 0
+            else (logs.get("wm_loss", 0.0) + logs.get("action_loss", 0.0))
+        )
+
+        return total_loss, logs

     # ---- LeRobot Policy Interface ----

     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        loss, logs = self.model(batch)
-        return loss, {key: value.item() for key, value in logs.items()}
+        """LeRobot train forward: convert → native forward → convert back."""
+        examples = self._lerobot_to_native(batch)
+        native_output = self.model.forward(examples)
+        return self._native_to_lerobot(native_output)

     def get_optim_params(self) -> dict:
         return self.model.parameters()

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:  # noqa: ARG002
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        """LeRobot inference: convert → native predict → return as Tensor."""
         self.eval()
-        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
-        return self.model.predict_action(batch)

+        # Convert to native format
+        examples = self._lerobot_to_native(batch)
+        batch_images = [ex["image"] for ex in examples]
+        instructions = [ex["lang"] for ex in examples]
+
+        state_np = None
+        if "state" in examples[0] and examples[0]["state"] is not None:
+            state_np = np.stack([ex["state"] for ex in examples])
+
+        # Call native predict
+        actions_np = self.model.predict_action(batch_images, instructions, state_np)
+
+        # Convert back to tensor on the right device
+        return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:  # noqa: ARG002
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        """LeRobot select_action with action queue caching."""
         self.eval()
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
         if len(self._queues[ACTION]) == 0:
-            actions = self.model.predict_action(batch)
+            actions = self.predict_action_chunk(batch)
             self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
         return self._queues[ACTION].popleft()
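To close, the action-queue mechanics in select_action are easy to isolate: a predicted chunk is cached and served one action per control tick, and the model only re-runs when the queue empties. A standalone trace with illustrative sizes:

    from collections import deque

    import torch

    n_action_steps = 8
    queue = deque(maxlen=n_action_steps)

    chunk = torch.rand(1, 16, 7)  # [B, horizon, action_dim], as from predict_action_chunk
    queue.extend(chunk.transpose(0, 1)[:n_action_steps])  # horizon-major: entries are [B, action_dim]

    first = queue.popleft()
    print(first.shape, len(queue))  # torch.Size([1, 7]) 7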