diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index 3609cc7c3..4785c7d3d 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -56,6 +56,7 @@ from .sac.configuration_sac import SACConfig
 from .smolvla.configuration_smolvla import SmolVLAConfig
 from .tdmpc.configuration_tdmpc import TDMPCConfig
 from .utils import validate_visual_features_consistency
+from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
 from .vqbet.configuration_vqbet import VQBeTConfig
 from .wall_x.configuration_wall_x import WallXConfig
 from .xvla.configuration_xvla import XVLAConfig
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from .eo1.modeling_eo1 import EO1Policy
 
         return EO1Policy
+    elif name == "vla_jepa":
+        from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
+
+        return VLAJEPAPolicy
     else:
         try:
             return _get_policy_cls_from_policy_name(name=name)
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return WallXConfig(**kwargs)
     elif policy_type == "eo1":
         return EO1Config(**kwargs)
+    elif policy_type == "vla_jepa":
+        return VLAJEPAConfig(**kwargs)
     else:
         try:
             config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -406,6 +413,7 @@ def make_pre_post_processors(
             config=policy_cfg,
             dataset_stats=kwargs.get("dataset_stats"),
         )
+
     elif isinstance(policy_cfg, EO1Config):
         from .eo1.processor_eo1 import make_eo1_pre_post_processors
 
         processors = make_eo1_pre_post_processors(
             config=policy_cfg,
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
+    elif isinstance(policy_cfg, VLAJEPAConfig):
+        from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
+
+        processors = make_vla_jepa_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )
+
     else:
         try:
             processors = _make_processors_from_policy_config(
diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py
index 2ff34e071..8e3cc94a9 100644
--- a/src/lerobot/policies/vla_jepa/action_head.py
+++ b/src/lerobot/policies/vla_jepa/action_head.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -203,7 +202,9 @@ class VLAJEPAActionHead(nn.Module):
             else None
         )
         self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size)
-        self.position_embedding = nn.Embedding(config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size)
+        self.position_embedding = nn.Embedding(
+            config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size
+        )
         self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
 
     def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@@ -268,7 +269,9 @@ class VLAJEPAActionHead(nn.Module):
         for step in range(self.num_inference_timesteps):
             t_cont = step / float(max(self.num_inference_timesteps, 1))
             t_value = int(t_cont * self.config.action_num_timestep_buckets)
-            timesteps = torch.full((batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long)
+            timesteps = torch.full(
+                (batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
+            )
             hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
             pred = self.model(
                 hidden_states=hidden_states,
diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
index 62d4f065d..5bc25fe32 100644
--- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
+++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
@@ -27,12 +27,14 @@ class VLAJEPAConfig(PreTrainedConfig):
     jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
 
     tokenizer_padding_side: str = "left"
-    prompt_template: str = "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
+    prompt_template: str = (
+        "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
+    )
     special_action_token: str = "<|action_{}|>"
     embodied_action_token: str = "<|embodied_action|>"
 
     action_dim: int = 7
-    state_dim: int = 7
+    state_dim: int = 8
     future_action_window_size: int = 15
     past_action_window_size: int = 0
     num_action_tokens_per_timestep: int = 4
@@ -42,7 +44,7 @@ class VLAJEPAConfig(PreTrainedConfig):
     action_hidden_size: int = 1024
     action_model_type: str = "DiT-B"
     action_num_layers: int = 12
-    action_num_heads: int = 12
+    action_num_heads: int = 16
     action_attention_head_dim: int = 64
     action_dropout: float = 0.1
     action_num_timestep_buckets: int = 1000
diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
index 429ddf96f..203460199 100644
--- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
+++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
@@ -3,29 +3,60 @@ from __future__ import annotations
 
 from collections import deque
 from pathlib import Path
 
+import numpy as np
 import torch
 import torch.nn.functional as F
+from PIL import Image
 from torch import Tensor, nn
 from transformers import AutoModel, AutoVideoProcessor
 
 from lerobot.policies.pretrained import PreTrainedPolicy, T
 from lerobot.policies.utils import populate_queues
-from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import ACTION, OBS_STATE
 
 from .action_head import VLAJEPAActionHead
 from .configuration_vla_jepa import VLAJEPAConfig
 from .qwen_interface import Qwen3VLInterface
 from .world_model import ActionConditionedVideoPredictor
 
+# ============================================================================
+# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
+# ============================================================================
+
 
 class VLAJEPAModel(nn.Module):
+    """
+    Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
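+
+    The action head is trained with a flow-matching objective on future action
+    chunks, while the V-JEPA predictor regresses next-frame video latents with
+    an L1 loss; the LeRobot adapter below sums the two (pre-weighted) losses.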
+
+    Components:
+        - Qwen3-VL: vision-language backbone for fused embeddings
+        - DiT-B: flow-matching action head for future action prediction
+        - V-JEPA: world model for video frame prediction
+
+    Input: List[dict] native format (same as original starVLA)
+        - "image": List[PIL.Image] (multi-view images)
+        - "video": np.ndarray [V, T, H, W, 3]
+        - "lang": str (task instruction)
+        - "action": np.ndarray [T, action_dim] (optional, training only)
+        - "state": np.ndarray [1, state_dim] (optional)
+    """
+
     def __init__(self, config: VLAJEPAConfig) -> None:
         super().__init__()
         self.config = config
+
+        # Vision-language backbone
         self.qwen = Qwen3VLInterface(config)
-        self.action_tokens, self.action_token_ids, self.embodied_action_token_id = self.qwen.expand_tokenizer()
+
+        # Tokenizer expansion for special action tokens
+        self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
+            self.qwen.expand_tokenizer()
+        )
+
+        # Action head (flow-matching DiT)
         self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
 
+        # JEPA world model components
         self.video_encoder = AutoModel.from_pretrained(
             config.jepa_encoder_name,
             torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
@@ -40,144 +71,224 @@ class VLAJEPAModel(nn.Module):
             mlp_ratio=config.predictor_mlp_ratio,
             num_action_tokens_per_step=config.num_action_tokens_per_timestep,
         )
+
+        # Build prompt placeholders (same as original)
         self.replace_prompt = "".join(
             token * self.config.num_action_tokens_per_timestep
             for token in self.action_tokens[: self.config.num_video_frames - 1]
         )
-        self.embodied_replace_prompt = self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
+        self.embodied_replace_prompt = (
+            self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
+        )
 
-    def _collect_images(self, batch: dict[str, Tensor]) -> list[list]:
-        sample_key = self.config.image_features[0]
-        batch_size = batch[sample_key].shape[0]
-        images = [[] for _ in range(batch_size)]
-        for key in self.config.image_features:
-            tensor = batch[key]
-            if tensor.ndim == 5:
-                tensor = tensor[:, -1]
-            for idx in range(batch_size):
-                images[idx].append(self.qwen.tensor_to_pil(tensor[idx]))
-        return images
+    # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
 
-    def _collect_videos(self, batch: dict[str, Tensor]) -> torch.Tensor:
-        first_key = self.config.image_features[0]
-        source = batch[first_key]
-        if source.ndim == 4:
-            source = source.unsqueeze(1).repeat(1, self.config.num_video_frames, 1, 1, 1)
-        elif source.ndim == 5 and source.shape[1] < self.config.num_video_frames:
-            pad = source[:, -1:].repeat(1, self.config.num_video_frames - source.shape[1], 1, 1, 1)
-            source = torch.cat([source, pad], dim=1)
-        elif source.ndim == 5:
-            source = source[:, -self.config.num_video_frames :]
-        else:
-            raise ValueError(f"Unsupported image tensor shape for JEPA: {tuple(source.shape)}")
-        return source
+    def forward(self, examples: list[dict]) -> dict[str, Tensor]:
+        """
+        Native forward pass following original starVLA VLA_JEPA.forward.
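+
+        Example (a minimal sketch; the names `model`/`ex` and all shapes are
+        illustrative, using the default config of action_dim=7, state_dim=8,
+        and future_action_window_size=15, i.e. a 16-step action horizon):
+            >>> ex = {
+            ...     "image": [Image.new("RGB", (256, 256))] * 2,       # V=2 views
+            ...     "video": np.zeros((2, 4, 256, 256, 3), np.uint8),  # [V, T, H, W, 3]
+            ...     "lang": "pick up the cube",
+            ...     "action": np.zeros((16, 7), np.float32),           # [T, action_dim]
+            ...     "state": np.zeros((1, 8), np.float32),             # [1, state_dim]
+            ... }
+            >>> losses = model([ex])  # {"action_loss": ..., "wm_loss": ...}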
-    def _get_tasks(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
-        tasks = batch.get("task")
-        if tasks is None:
-            return ["Execute the robot action."] * next(iter(batch.values())).shape[0]
-        if isinstance(tasks, str):
-            return [tasks]
-        return list(tasks)
+
+        Args:
+            examples: List of per-sample dicts with keys:
+                "image": List[PIL.Image] — multi-view images
+                "video": np.ndarray [V, T, H, W, 3]
+                "lang": str — task instruction
+                "action": np.ndarray [T, action_dim] (optional)
+                "state": np.ndarray [1, state_dim] (optional)
 
-    def _extract_qwen_conditioning(self, batch: dict[str, Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        images = self._collect_images(batch)
-        tasks = self._get_tasks(batch)
+        Returns:
+            dict with "action_loss" and "wm_loss" keys (scalar Tensors).
+        """
+        # Unpack native format (same pattern as original VLA_JEPA.py)
+        batch_images = [ex["image"] for ex in examples]  # List[List[PIL.Image]]
+        batch_videos = [ex["video"] for ex in examples]  # List[np.ndarray]
+        instructions = [ex["lang"] for ex in examples]  # List[str]
+        has_action = "action" in examples[0] and examples[0]["action"] is not None
+        actions = [ex["action"] for ex in examples] if has_action else None
+        has_state = "state" in examples[0] and examples[0]["state"] is not None
+        state = [ex["state"] for ex in examples] if has_state else None
+
+        # Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
+        batch_videos = np.stack(batch_videos)
+        batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4)  # [B, V, T, 3, H, W]
+
+        # ---- Step 1: QwenVL encode (same as original) ----
         qwen_inputs = self.qwen.build_inputs(
-            images=images,
-            instructions=tasks,
+            images=batch_images,
+            instructions=instructions,
             action_prompt=self.replace_prompt,
             embodied_prompt=self.embodied_replace_prompt,
         )
-        outputs = self.qwen.model(
-            **qwen_inputs,
-            output_hidden_states=True,
-            output_attentions=False,
-            return_dict=True,
-        )
-        hidden = outputs.hidden_states[-1]
+
+        # Locate action and embodied-action tokens in the tokenized sequence
         action_mask = torch.isin(
             qwen_inputs["input_ids"],
             torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
         )
         action_indices = action_mask.nonzero(as_tuple=True)
-        action_tokens = hidden[action_indices[0], action_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
         embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
         embodied_indices = embodied_mask.nonzero(as_tuple=True)
-        embodied_tokens = hidden[embodied_indices[0], embodied_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
-        return action_tokens, embodied_tokens
 
-    def _prepare_state(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
-        if OBS_STATE not in batch:
-            return None
-        state = batch[OBS_STATE]
-        if state.ndim > 2:
-            state = state[:, -1, :]
-        return state.to(device=device, dtype=dtype)
+        with torch.autocast("cuda", dtype=torch.bfloat16):
+            qwen_outputs = self.qwen.model(
+                **qwen_inputs,
+                output_hidden_states=True,
+                output_attentions=False,
+                return_dict=True,
+            )
+        last_hidden = qwen_outputs.hidden_states[-1]  # [B, seq_len, H]
+        B, _, H = last_hidden.shape
 
-    def _prepare_action_targets(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
-        actions = batch[ACTION]
-        if actions.ndim == 2:
-            actions = actions.unsqueeze(1)
-        horizon = self.config.future_action_window_size + 1
-        if actions.shape[1] < horizon:
-            pad = actions[:, -1:].repeat(1, horizon - actions.shape[1], 1)
-            actions = torch.cat([actions, pad], dim=1)
-        return actions[:, -horizon:].to(device=device, dtype=dtype)
+        action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(B, -1, H)
 
-    def _encode_video(self, video_tensor: torch.Tensor) -> torch.Tensor:
-        processed = []
-        for sample in video_tensor:
-            processed_sample = self.video_processor(videos=sample, return_tensors="pt")["pixel_values_videos"]
-            processed.append(processed_sample)
-        pixel_values = torch.cat(processed, dim=0).to(self.video_encoder.device)
-        return self.video_encoder.get_vision_features(pixel_values_videos=pixel_values)
+        embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
 
-    def _compute_world_model_loss(self, batch: dict[str, Tensor], action_tokens: torch.Tensor) -> torch.Tensor | None:
-        if not self.config.enable_world_model:
-            return None
-        video_tensor = self._collect_videos(batch)
-        video_features = self._encode_video(video_tensor)
-        batch_size = video_tensor.shape[0]
-        num_frames = video_tensor.shape[1]
-        tokens_per_frame = video_features.shape[1] // num_frames
-        video_features = video_features.view(batch_size, num_frames, tokens_per_frame, -1)
-        input_states = video_features[:, :-1]
-        gt_states = video_features[:, 1:]
+
+        # ---- Step 2: JEPA Encoder (same as original) ----
+        B, V, T_frames, C, H_img, W_img = batch_videos.shape
+        batch_videos_flat = batch_videos.reshape(B * V, T_frames, C, H_img, W_img)
 
-        expected_tokens = (num_frames - 1) * self.config.num_action_tokens_per_timestep
-        if action_tokens.shape[1] < expected_tokens:
-            pad = action_tokens[:, -1:].repeat(1, expected_tokens - action_tokens.shape[1], 1)
-            action_tokens = torch.cat([action_tokens, pad], dim=1)
-        action_tokens = action_tokens[:, :expected_tokens]
-        action_tokens = action_tokens.view(batch_size, num_frames - 1, self.config.num_action_tokens_per_timestep, -1)
-        pred_states = self.video_predictor(input_states, action_tokens)
-        return F.l1_loss(pred_states, gt_states, reduction="mean")
+        video_pixels = []
+        for i in range(B * V):
+            video_pixels.append(
+                self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
+                    "pixel_values_videos"
+                ].to(self.video_encoder.device)
+            )
+        video_pixels = torch.cat(video_pixels, dim=0)  # [B*V, T, C, H, W]
 
-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
-        action_tokens, embodied_tokens = self._extract_qwen_conditioning(batch)
-        state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
-        target_actions = self._prepare_action_targets(batch, embodied_tokens.device, embodied_tokens.dtype)
-        action_loss = self.action_model(embodied_tokens, target_actions, state)
+        with torch.no_grad():
+            video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
+            # Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
+            video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=V, dim=0), dim=2)
 
-        wm_loss = self._compute_world_model_loss(batch, action_tokens)
-        total_loss = action_loss
-        logs = {"action_loss": action_loss.detach()}
-        if wm_loss is not None:
-            total_loss = total_loss + self.config.world_model_loss_weight * wm_loss
-            logs["wm_loss"] = wm_loss.detach()
-        logs["loss"] = total_loss.detach()
-        return total_loss, logs
+        # ---- Step 3: JEPA Predictor (same as original) ----
+        tubelet_size = self.video_encoder.config.tubelet_size
+        T_enc = T_frames // tubelet_size
+        device_wm = video_embeddings.device
+
+        if T_enc < 2:
+            # Not enough frames for JEPA prediction (need at least 2 encoded frames)
+            wm_loss = torch.tensor(0.0, device=device_wm)
+        else:
+            tokens_per_frame = video_embeddings.shape[1] // T_enc
+
+            # input_states: frames 0..T-2  [B, (T-1)*tokens_per_frame, D]
+            # gt_states:    frames 1..T-1  [B, (T-1)*tokens_per_frame, D]
+            input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :]
+            gt_states = video_embeddings[:, tokens_per_frame:, :]
+            D_emb = input_states.shape[-1]
+
+            # Reshape to 4D for ActionConditionedVideoPredictor:
+            # [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
+            input_states_4d = input_states.view(B, T_enc - 1, tokens_per_frame, D_emb)
+
+            # Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
+            expected_actions = (T_enc - 1) * self.config.num_action_tokens_per_timestep
+            if action_tokens.shape[1] < expected_actions:
+                pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
+                action_tokens = torch.cat([action_tokens, pad], dim=1)
+            act_4d = action_tokens[:, :expected_actions].view(
+                B, T_enc - 1, self.config.num_action_tokens_per_timestep, -1
+            )
+
+            # Cast to float32 for predictor (Linear layers are float32)
+            pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
+            predicted_states = pred_4d.reshape(B, -1, D_emb)
+
+            wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
+
+        if not has_action:
+            # Apply the same loss weighting as the joint branch below
+            return {"wm_loss": wm_loss * self.config.world_model_loss_weight}
+
+        # ---- Step 4: Action Head (same as original) ----
+        # CUDA autocast does not support float32; disable autocast so the action
+        # head runs in full precision even under an outer autocast context.
+        with torch.autocast("cuda", enabled=False):
+            actions_tensor = torch.tensor(
+                np.array(actions), device=last_hidden.device, dtype=torch.float32
+            )  # [B, T_full, action_dim]
+            action_horizon = self.config.future_action_window_size + 1
+            actions_target = actions_tensor[:, -action_horizon:, :]
+
+            state_tensor = None
+            if state is not None:
+                state_tensor = torch.tensor(
+                    np.array(state), device=last_hidden.device, dtype=torch.float32
+                )  # [B, 1, state_dim]
+
+            # Cast embodied tokens to float32 for action model compatibility
+            action_loss = self.action_model(embodied_action_tokens.float(), actions_target, state_tensor)
+
+        return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
+
+    # ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
 
     @torch.no_grad()
-    def predict_action(self, batch: dict[str, Tensor]) -> Tensor:
-        _, embodied_tokens = self._extract_qwen_conditioning(batch)
-        state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
-        return self.action_model.predict_action(embodied_tokens, state)
+    def predict_action(
+        self,
+        batch_images: list[list[Image.Image]],
+        instructions: list[str],
+        state: np.ndarray | None = None,
+    ) -> np.ndarray:
+        """
+        Native action prediction following original VLA_JEPA.predict_action.
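+
+        Call sketch (file names and shapes are illustrative; state_dim=8 is the
+        default config value):
+            >>> views = [[Image.open("front.png"), Image.open("wrist.png")]]
+            >>> acts = model.predict_action(
+            ...     views, ["stack the blocks"], state=np.zeros((1, 1, 8), np.float32)
+            ... )
+            >>> acts.shape  # (1, future_action_window_size + 1, action_dim)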
+
+        Args:
+            batch_images: List of samples; each is List[PIL.Image] (multi-view).
+            instructions: Task instructions, one per sample.
+            state: Optional numpy array, [B, state_dim] or [B, 1, state_dim]
+                (the LeRobot adapter passes [B, 1, state_dim]).
+
+        Returns:
+            np.ndarray [B, action_horizon, action_dim] — predicted actions.
+        """
+        qwen_inputs = self.qwen.build_inputs(
+            images=batch_images,
+            instructions=instructions,
+            action_prompt=self.replace_prompt,
+            embodied_prompt=self.embodied_replace_prompt,
+        )
+
+        embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
+        embodied_indices = embodied_mask.nonzero(as_tuple=True)
+
+        with torch.autocast("cuda", dtype=torch.bfloat16):
+            qwen_outputs = self.qwen.model(
+                **qwen_inputs,
+                output_hidden_states=True,
+                output_attentions=False,
+                return_dict=True,
+            )
+        last_hidden = qwen_outputs.hidden_states[-1]
+        B, _, H = last_hidden.shape
+        embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
+
+        state_tensor = None
+        if state is not None:
+            state_tensor = torch.from_numpy(np.array(state)).to(
+                device=last_hidden.device, dtype=torch.float32
+            )
+
+        # CUDA autocast does not support float32; disable autocast so the action
+        # head runs in full precision even under an outer autocast context.
+        with torch.autocast("cuda", enabled=False):
+            # Cast embodied tokens to float32 for action model compatibility
+            pred_actions = self.action_model.predict_action(
+                embodied_action_tokens.float(), state_tensor
+            )  # [B, action_horizon, action_dim]
+
+        return pred_actions.detach().cpu().numpy()
+
+
+# ============================================================================
+# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
+# ============================================================================
 
 
 class VLAJEPAPolicy(PreTrainedPolicy):
+    """
+    LeRobot adapter for VLA-JEPA.
+
+    Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
+    VLA-JEPA format (List[dict]), calls the native model, and converts outputs
+    back to LeRobot format.
+    """
+
     config_class = VLAJEPAConfig
     name = "vla_jepa"
 
@@ -190,22 +301,193 @@ class VLAJEPAPolicy(PreTrainedPolicy):
     def reset(self) -> None:
         self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
 
+    # ---- Format Conversion: LeRobot → Native ----
+
+    def _lerobot_to_native(self, batch: dict[str, Tensor]) -> list[dict]:
+        """
+        Convert LeRobot batch format to native VLA-JEPA examples format.
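+
+        For example, a batch with two camera keys and batch size B yields B native
+        examples, each holding a two-view image list and a [2, T, H, W, 3] video array.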
+
+        LeRobot format:
+            batch = {
+                "observation.images.": Tensor [B, C, H, W] or [B, T, C, H, W],
+                "observation.state": Tensor [B, state_dim] or [B, T, state_dim],
+                "action": Tensor [B, chunk_size, action_dim],  (training only)
+                "task": str | List[str],  (optional instruction)
+            }
+
+        Native format (List[dict]):
+            {
+                "image": List[PIL.Image],  # multi-view images per sample
+                "video": np.ndarray [V, T, H, W, 3],
+                "lang": str,  # task instruction
+                "action": np.ndarray [T, action_dim],  # optional
+                "state": np.ndarray [1, state_dim],  # optional
+            }
+        """
+        # Determine batch size from the first image feature
+        image_keys = list(self.config.image_features.keys())
+        if not image_keys:
+            raise ValueError("VLAJEPA requires at least one image feature.")
+        first_key = image_keys[0]
+        first_tensor = batch[first_key]
+        batch_size = first_tensor.shape[0]
+
+        # ---- Collect images per sample ----
+        # images_per_sample[b][v] = PIL.Image for view v
+        images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
+        for key in image_keys:
+            tensor = batch[key]  # [B, C, H, W] or [B, T, C, H, W]
+            if tensor.ndim == 5:
+                # Multi-frame: take the last frame as the "current" image
+                tensor = tensor[:, -1]
+            for b in range(batch_size):
+                images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
+
+        # ---- Collect videos per sample ----
+        # Build video arrays: for each sample, stack views as [V, T, H, W, 3]
+        videos_per_sample = []
+        for b in range(batch_size):
+            sample_views = []
+            for k in image_keys:
+                t = batch[k][b]  # [C, H, W] or [T, C, H, W]
+                if t.ndim == 3:
+                    t = t.unsqueeze(0)  # [1, C, H, W]
+                # Convert to [T, H, W, 3] numpy
+                t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
+                # Heuristic: rescale [0, 1] floats to [0, 255], then clamp
+                if t_np.max() <= 1.0:
+                    t_np = t_np * 255.0
+                t_np = t_np.clip(0, 255).astype(np.uint8)
+                sample_views.append(t_np)
+            # Stack views: [V, T, H, W, 3]
+            videos_per_sample.append(np.stack(sample_views, axis=0))
+
+        # ---- Collect instructions ----
+        tasks = batch.get("task")
+        if tasks is None:
+            instructions = ["Execute the robot action."] * batch_size
+        elif isinstance(tasks, str):
+            instructions = [tasks] * batch_size
+        else:
+            instructions = list(tasks)
+
+        # ---- Collect actions (training only) ----
+        actions_list = None
+        if ACTION in batch:
+            actions_tensor = batch[ACTION]  # [B, chunk_size, action_dim]
+            if actions_tensor.ndim == 2:
+                actions_tensor = actions_tensor.unsqueeze(1)
+            actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
+
+        # ---- Collect state ----
+        state_list = None
+        if OBS_STATE in batch:
+            state_tensor = batch[OBS_STATE]  # [B, state_dim] or [B, T, state_dim]
+            if state_tensor.ndim > 2:
+                state_tensor = state_tensor[:, -1, :]
+            if state_tensor.ndim == 2:
+                state_tensor = state_tensor.unsqueeze(1)  # [B, 1, state_dim]
+            state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
+
+        # ---- Assemble native examples ----
+        examples = []
+        for b in range(batch_size):
+            example = {
+                "image": images_per_sample[b],
+                "video": videos_per_sample[b],
+                "lang": instructions[b],
+            }
+            if actions_list is not None:
+                example["action"] = actions_list[b]
+            if state_list is not None:
+                example["state"] = state_list[b]
+            examples.append(example)
+
+        return examples
+
+    # ---- Format Conversion: Native → LeRobot ----
+
+    def _native_to_lerobot(self, native_output: dict[str, Tensor]) -> tuple[Tensor, dict[str, float]]:
+        """
+        Convert native VLA-JEPA output dict to LeRobot (loss, logs) format.
+
+        Native output:
+            {"action_loss": Tensor, "wm_loss": Tensor}
+            or {"wm_loss": Tensor}  (video-only mode)
+
+        LeRobot output:
+            (total_loss: scalar Tensor,
+             {"action_loss": float, "wm_loss": float, "loss": float})
+        """
+        logs: dict[str, float] = {}
+        total_loss = torch.tensor(0.0, device=self.config.device)
+
+        if "action_loss" in native_output:
+            total_loss = total_loss + native_output["action_loss"]
+            logs["action_loss"] = native_output["action_loss"].detach().item()
+
+        if "wm_loss" in native_output:
+            # The world-model loss is already weighted by the native model, but it
+            # must still be added so it contributes to the optimized total.
+            wm_loss = native_output["wm_loss"]
+            total_loss = total_loss + wm_loss
+            logs["wm_loss"] = wm_loss.detach().item()
+
+        logs["loss"] = total_loss.detach().item()
+
+        return total_loss, logs
+
+    # ---- LeRobot Policy Interface ----
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        loss, logs = self.model(batch)
-        return loss, {key: value.item() for key, value in logs.items()}
+        """LeRobot train forward: convert → native forward → convert back."""
+        examples = self._lerobot_to_native(batch)
+        native_output = self.model.forward(examples)
+        return self._native_to_lerobot(native_output)
+
+    def get_optim_params(self) -> dict:
+        return self.model.parameters()
 
     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:  # noqa: ARG002
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        """LeRobot inference: convert → native predict → return as Tensor."""
         self.eval()
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
-        return self.model.predict_action(batch)
+
+        # Convert to native format
+        examples = self._lerobot_to_native(batch)
+        batch_images = [ex["image"] for ex in examples]
+        instructions = [ex["lang"] for ex in examples]
+
+        state_np = None
+        if "state" in examples[0] and examples[0]["state"] is not None:
+            state_np = np.stack([ex["state"] for ex in examples])
+
+        # Call native predict
+        actions_np = self.model.predict_action(batch_images, instructions, state_np)
+
+        # Convert back to tensor on the right device
+        return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
 
     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:  # noqa: ARG002
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        """LeRobot select_action with action queue caching."""
        self.eval()
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
         if len(self._queues[ACTION]) == 0:
-            actions = self.model.predict_action(batch)
+            actions = self.predict_action_chunk(batch)
             self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
         return self._queues[ACTION].popleft()