Add support for the VLA-JEPA policy (vla_jepa)

This commit is contained in:
ginwind
2026-05-01 11:39:06 +00:00
committed by Maxime Ellerbach
parent e93fd2bcfe
commit 2757266f6b
4 changed files with 423 additions and 120 deletions
+16
View File
@@ -56,6 +56,7 @@ from .sac.configuration_sac import SACConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -406,6 +413,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -414,6 +422,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+6 -3
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
@@ -203,7 +202,9 @@ class VLAJEPAActionHead(nn.Module):
else None
)
self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size)
self.position_embedding = nn.Embedding(config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size)
self.position_embedding = nn.Embedding(
config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@@ -268,7 +269,9 @@ class VLAJEPAActionHead(nn.Module):
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full((batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
@@ -27,12 +27,14 @@ class VLAJEPAConfig(PreTrainedConfig):
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
tokenizer_padding_side: str = "left"
prompt_template: str = "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
prompt_template: str = (
"{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
)
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 7
state_dim: int = 8
future_action_window_size: int = 15
past_action_window_size: int = 0
num_action_tokens_per_timestep: int = 4
@@ -42,7 +44,7 @@ class VLAJEPAConfig(PreTrainedConfig):
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 12
action_num_heads: int = 12
action_num_heads: int = 16
action_attention_head_dim: int = 64
action_dropout: float = 0.1
action_num_timestep_buckets: int = 1000
+396 -114
View File
@@ -3,29 +3,60 @@ from __future__ import annotations
from collections import deque
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import Tensor, nn
from transformers import AutoModel, AutoVideoProcessor
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_STATE
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = self.qwen.expand_tokenizer()
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
@@ -40,144 +71,224 @@ class VLAJEPAModel(nn.Module):
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
# Build prompt placeholders (same as original)
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[: self.config.num_video_frames - 1]
)
self.embodied_replace_prompt = self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _collect_images(self, batch: dict[str, Tensor]) -> list[list]:
    """Gather per-sample multi-view PIL images from a LeRobot batch.

    Each configured image feature contributes one view per sample; a
    5-D (temporal) tensor is reduced to its most recent frame first.
    """
    first_feature = self.config.image_features[0]
    num_samples = batch[first_feature].shape[0]
    per_sample: list[list] = [[] for _ in range(num_samples)]
    for feature_key in self.config.image_features:
        frames = batch[feature_key]
        # A 5-D tensor carries a time axis; keep only the latest frame.
        if frames.ndim == 5:
            frames = frames[:, -1]
        for sample_idx, views in enumerate(per_sample):
            views.append(self.qwen.tensor_to_pil(frames[sample_idx]))
    return per_sample
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def _collect_videos(self, batch: dict[str, Tensor]) -> torch.Tensor:
    """Build a fixed-length clip [B, T, C, H, W] from the first image feature.

    A single frame is tiled across time; a short clip is padded by
    repeating its last frame; a long clip keeps only the trailing frames.
    """
    frames = batch[self.config.image_features[0]]
    target_len = self.config.num_video_frames
    if frames.ndim == 4:
        # [B, C, H, W] -> replicate the lone frame over the time axis.
        return frames.unsqueeze(1).repeat(1, target_len, 1, 1, 1)
    if frames.ndim == 5:
        current_len = frames.shape[1]
        if current_len < target_len:
            # Pad by repeating the final frame until the clip is full length.
            tail = frames[:, -1:].repeat(1, target_len - current_len, 1, 1, 1)
            return torch.cat([frames, tail], dim=1)
        # Equal or longer: keep the most recent target_len frames.
        return frames[:, -target_len:]
    raise ValueError(f"Unsupported image tensor shape for JEPA: {tuple(frames.shape)}")
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
def _get_tasks(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
tasks = batch.get("task")
if tasks is None:
return ["Execute the robot action."] * next(iter(batch.values())).shape[0]
if isinstance(tasks, str):
return [tasks]
return list(tasks)
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] — multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str — task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
def _extract_qwen_conditioning(self, batch: dict[str, Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
images = self._collect_images(batch)
tasks = self._get_tasks(batch)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=images,
instructions=tasks,
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
hidden = outputs.hidden_states[-1]
# Locate action and embodied-action tokens in the tokenized sequence
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
action_tokens = hidden[action_indices[0], action_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
embodied_tokens = hidden[embodied_indices[0], embodied_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
return action_tokens, embodied_tokens
def _prepare_state(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
    """Extract the latest proprioceptive state, or None when absent.

    Temporal state tensors keep only their final timestep; the result is
    moved to the requested device and dtype.
    """
    if OBS_STATE not in batch:
        return None
    state_tensor = batch[OBS_STATE]
    # Collapse any time axis by keeping the most recent observation.
    state_tensor = state_tensor[:, -1, :] if state_tensor.ndim > 2 else state_tensor
    return state_tensor.to(device=device, dtype=dtype)
with torch.autocast("cuda", dtype=torch.bfloat16):
qwen_outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
B, _, H = last_hidden.shape
def _prepare_action_targets(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build [B, horizon, action_dim] supervision targets from the batch.

    The horizon is future_action_window_size + 1 (current step included).
    A 2-D action tensor gains a time axis; a short chunk is padded by
    repeating its last action; a long chunk keeps only the trailing steps.
    """
    targets = batch[ACTION]
    if targets.ndim == 2:
        targets = targets.unsqueeze(1)
    horizon = self.config.future_action_window_size + 1
    missing = horizon - targets.shape[1]
    if missing > 0:
        # Repeat the final action to fill out a too-short chunk.
        filler = targets[:, -1:].repeat(1, missing, 1)
        targets = torch.cat([targets, filler], dim=1)
    return targets[:, -horizon:].to(device=device, dtype=dtype)
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(B, -1, H)
def _encode_video(self, video_tensor: torch.Tensor) -> torch.Tensor:
    """Encode a batch of clips with the V-JEPA video encoder.

    Each sample is preprocessed individually by the HF video processor,
    then the batch is re-assembled on the encoder's device and encoded.
    """
    per_sample_pixels = [
        self.video_processor(videos=clip, return_tensors="pt")["pixel_values_videos"]
        for clip in video_tensor
    ]
    pixel_batch = torch.cat(per_sample_pixels, dim=0).to(self.video_encoder.device)
    return self.video_encoder.get_vision_features(pixel_values_videos=pixel_batch)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
def _compute_world_model_loss(self, batch: dict[str, Tensor], action_tokens: torch.Tensor) -> torch.Tensor | None:
if not self.config.enable_world_model:
return None
video_tensor = self._collect_videos(batch)
video_features = self._encode_video(video_tensor)
batch_size = video_tensor.shape[0]
num_frames = video_tensor.shape[1]
tokens_per_frame = video_features.shape[1] // num_frames
video_features = video_features.view(batch_size, num_frames, tokens_per_frame, -1)
input_states = video_features[:, :-1]
gt_states = video_features[:, 1:]
# ---- Step 2: JEPA Encoder (same as original) ----
B, V, T_frames, C, H_img, W_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(B * V, T_frames, C, H_img, W_img)
expected_tokens = (num_frames - 1) * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_tokens:
pad = action_tokens[:, -1:].repeat(1, expected_tokens - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
action_tokens = action_tokens[:, :expected_tokens]
action_tokens = action_tokens.view(batch_size, num_frames - 1, self.config.num_action_tokens_per_timestep, -1)
pred_states = self.video_predictor(input_states, action_tokens)
return F.l1_loss(pred_states, gt_states, reduction="mean")
video_pixels = []
for i in range(B * V):
video_pixels.append(
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device)
)
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
action_tokens, embodied_tokens = self._extract_qwen_conditioning(batch)
state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
target_actions = self._prepare_action_targets(batch, embodied_tokens.device, embodied_tokens.dtype)
action_loss = self.action_model(embodied_tokens, target_actions, state)
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=V, dim=0), dim=2)
wm_loss = self._compute_world_model_loss(batch, action_tokens)
total_loss = action_loss
logs = {"action_loss": action_loss.detach()}
if wm_loss is not None:
total_loss = total_loss + self.config.world_model_loss_weight * wm_loss
logs["wm_loss"] = wm_loss.detach()
logs["loss"] = total_loss.detach()
return total_loss, logs
# ---- Step 3: JEPA Predictor (same as original) ----
tubelet_size = self.video_encoder.config.tubelet_size
T_enc = T_frames // tubelet_size
device_wm = video_embeddings.device
if T_enc < 2:
# Not enough frames for JEPA prediction (need at least 2 encoded frames)
wm_loss = torch.tensor(0.0, device=device_wm)
else:
tokens_per_frame = video_embeddings.shape[1] // T_enc
# input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
# gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
D_emb = input_states.shape[-1]
# Reshape to 4D for ActionConditionedVideoPredictor:
# [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
input_states_4d = input_states.view(B, T_enc - 1, tokens_per_frame, D_emb)
# Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
expected_actions = (T_enc - 1) * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
act_4d = action_tokens[:, :expected_actions].view(
B, T_enc - 1, self.config.num_action_tokens_per_timestep, -1
)
# Cast to float32 for predictor (Linear layers are float32)
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
predicted_states = pred_4d.reshape(B, -1, D_emb)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head (same as original) ----
with torch.autocast("cuda", dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.future_action_window_size + 1
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=torch.float32
) # [B, 1, state_dim]
# Cast embodied tokens to float32 for action model compatibility
action_loss = self.action_model(embodied_action_tokens.float(), actions_target, state_tensor)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(self, batch: dict[str, Tensor]) -> Tensor:
    """Predict an action chunk from a LeRobot-format batch (no gradients).

    Conditions the flow-matching action head on the embodied-action token
    embeddings from the Qwen backbone, plus the latest state if present.
    """
    # Only the embodied-action conditioning tokens are needed at inference;
    # the per-frame action tokens are discarded.
    _, embodied_tokens = self._extract_qwen_conditioning(batch)
    state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
    return self.action_model.predict_action(embodied_tokens, state)
def predict_action(
    self,
    batch_images: list[list[Image.Image]],
    instructions: list[str],
    state: np.ndarray | None = None,
) -> np.ndarray:
    """
    Native action prediction following original VLA_JEPA.predict_action.

    Args:
        batch_images: List of samples; each is List[PIL.Image] (multi-view).
        instructions: Task instructions, one per sample.
        state: Optional [B, state_dim] numpy array.

    Returns:
        np.ndarray [B, action_horizon, action_dim] — predicted actions.
    """
    # Build the prompt with action / embodied-action placeholder tokens.
    qwen_inputs = self.qwen.build_inputs(
        images=batch_images,
        instructions=instructions,
        action_prompt=self.replace_prompt,
        embodied_prompt=self.embodied_replace_prompt,
    )
    # Locate embodied-action token positions in the tokenized sequence.
    embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
    embodied_indices = embodied_mask.nonzero(as_tuple=True)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        qwen_outputs = self.qwen.model(
            **qwen_inputs,
            output_hidden_states=True,
            output_attentions=False,
            return_dict=True,
        )
    last_hidden = qwen_outputs.hidden_states[-1]
    B, _, H = last_hidden.shape
    # Gather hidden states at embodied-action positions; the .view(B, -1, H)
    # assumes every sample has the same number of such tokens — TODO confirm.
    embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)

    state_tensor = None
    if state is not None:
        state_tensor = torch.from_numpy(np.array(state)).to(
            device=last_hidden.device, dtype=torch.float32
        )

    # NOTE(review): torch.autocast with dtype=torch.float32 is unusual —
    # autocast normally expects float16/bfloat16 on CUDA; confirm intended.
    with torch.autocast("cuda", dtype=torch.float32):
        # Cast embodied tokens to float32 for action model compatibility
        pred_actions = self.action_model.predict_action(
            embodied_action_tokens.float(), state_tensor
        )  # [B, action_horizon, action_dim]
    return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
@@ -190,22 +301,193 @@ class VLAJEPAPolicy(PreTrainedPolicy):
def reset(self) -> None:
    """Drop any cached action chunk; call at the start of a new episode."""
    action_queue = deque(maxlen=self.config.n_action_steps)
    self._queues = {ACTION: action_queue}
# ---- Format Conversion: LeRobot → Native ----
def _lerobot_to_native(self, batch: dict[str, Tensor]) -> list[dict]:
    """
    Convert LeRobot batch format to native VLA-JEPA examples format.

    LeRobot format:
        batch = {
            "observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
            "observation.state": Tensor [B, state_dim] or [B, T, state_dim],
            "action": Tensor [B, chunk_size, action_dim],  (training only)
            "task": str | List[str],  (optional instruction)
        }

    Native format (List[dict]):
        {
            "image": List[PIL.Image],       # multi-view images per sample
            "video": np.ndarray [V, T, H, W, 3],
            "lang": str,                    # task instruction
            "action": np.ndarray [T, action_dim],  # optional
            "state": np.ndarray [1, state_dim],    # optional
        }

    Raises:
        ValueError: when the config declares no image features.
        KeyError: when a configured image feature is missing from the batch.
    """
    # Determine batch size from the first image feature.
    image_keys = list(self.config.image_features.keys())
    if not image_keys:
        raise ValueError("VLAJEPA requires at least one image feature.")
    first_key = image_keys[0]
    batch_size = batch[first_key].shape[0]

    # ---- Collect images per sample ----
    # images_per_sample[b][v] = PIL.Image for view v of sample b.
    images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
    for key in image_keys:
        tensor = batch[key]  # [B, C, H, W] or [B, T, C, H, W]
        if tensor.ndim == 5:
            # Multi-frame: take the last frame as the "current" image.
            tensor = tensor[:, -1]
        for b in range(batch_size):
            images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))

    # ---- Collect videos per sample ----
    # For each sample, stack views as [V, T, H, W, 3] uint8.
    videos_per_sample = []
    for b in range(batch_size):
        sample_views = []
        for k in image_keys:
            t = batch[k][b]  # [C, H, W] or [T, C, H, W]
            if t.ndim == 3:
                t = t.unsqueeze(0)  # [1, C, H, W]
            # Convert to [T, H, W, 3] numpy.
            t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
            # Heuristic: values in [0, 1] are assumed normalized and are
            # rescaled to [0, 255] — TODO confirm against dataset stats.
            if t_np.max() <= 1.0:
                t_np = t_np * 255.0
            t_np = t_np.clip(0, 255).astype(np.uint8)
            sample_views.append(t_np)
        # Stack views: [V, T, H, W, 3]
        videos_per_sample.append(np.stack(sample_views, axis=0))

    # ---- Collect instructions ----
    tasks = batch.get("task")
    if tasks is None:
        instructions = ["Execute the robot action."] * batch_size
    elif isinstance(tasks, str):
        instructions = [tasks] * batch_size
    else:
        instructions = list(tasks)

    # ---- Collect actions (training only) ----
    actions_list = None
    if ACTION in batch:
        actions_tensor = batch[ACTION]  # [B, chunk_size, action_dim]
        if actions_tensor.ndim == 2:
            actions_tensor = actions_tensor.unsqueeze(1)
        actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]

    # ---- Collect state ----
    state_list = None
    if OBS_STATE in batch:
        state_tensor = batch[OBS_STATE]  # [B, state_dim] or [B, T, state_dim]
        if state_tensor.ndim > 2:
            state_tensor = state_tensor[:, -1, :]
        if state_tensor.ndim == 2:
            state_tensor = state_tensor.unsqueeze(1)  # [B, 1, state_dim]
        state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]

    # ---- Assemble native examples ----
    examples = []
    for b in range(batch_size):
        example = {
            "image": images_per_sample[b],
            "video": videos_per_sample[b],
            "lang": instructions[b],
        }
        if actions_list is not None:
            example["action"] = actions_list[b]
        if state_list is not None:
            example["state"] = state_list[b]
        examples.append(example)
    return examples
# ---- Format Conversion: Native → LeRobot ----
def _native_to_lerobot(self, native_output: dict[str, Tensor]) -> tuple[Tensor, dict[str, float]]:
    """
    Convert native VLA-JEPA output dict to LeRobot (loss, logs) format.

    Native output:
        {"action_loss": Tensor, "wm_loss": Tensor}
        or {"wm_loss": Tensor} (video-only mode)
    The world-model loss arrives already weighted by the native model.

    LeRobot output:
        (total_loss: scalar Tensor, {"action_loss": float, "wm_loss": float, "loss": float})
    """
    logs: dict[str, float] = {}
    total_loss = torch.tensor(0.0, device=self.config.device)
    if "action_loss" in native_output:
        total_loss = total_loss + native_output["action_loss"]
        logs["action_loss"] = native_output["action_loss"].detach().item()
    if "wm_loss" in native_output:
        wm_loss = native_output["wm_loss"]
        # Bug fix: the world-model loss must contribute to the optimized
        # total, not only to the logs — otherwise video-only training
        # backpropagates a constant zero and joint training drops the
        # world-model gradient entirely.
        total_loss = total_loss + wm_loss
        logs["wm_loss"] = wm_loss.detach().item()
    # "loss" now always reflects the full optimized objective.
    logs["loss"] = total_loss.detach().item()
    return total_loss, logs
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
loss, logs = self.model(batch)
return loss, {key: value.item() for key, value in logs.items()}
"""LeRobot train forward: convert → native forward → convert back."""
examples = self._lerobot_to_native(batch)
native_output = self.model.forward(examples)
return self._native_to_lerobot(native_output)
def get_optim_params(self) -> dict:
    """Return the parameters to optimize (all wrapped-model parameters).

    NOTE(review): the declared return type is ``dict`` but
    ``nn.Module.parameters()`` returns an iterator of tensors — confirm
    what the optimizer-construction caller actually expects here.
    """
    return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
return self.model.predict_action(batch)
# Convert to native format
examples = self._lerobot_to_native(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
# Call native predict
actions_np = self.model.predict_action(batch_images, instructions, state_np)
# Convert back to tensor on the right device
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.model.predict_action(batch)
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()