From 31ddb8f493cbb2ccd0409f326439b62eccfed88e Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Tue, 9 Jun 2026 12:14:43 +0000 Subject: [PATCH] major refactor of the forward pass and model input conversion --- .../policies/vla_jepa/modeling_vla_jepa.py | 436 +++++++----------- .../policies/vla_jepa/qwen_interface.py | 9 +- tests/policies/vla_jepa/conftest.py | 6 +- tests/policies/vla_jepa/test_vla_jepa.py | 37 +- 4 files changed, 184 insertions(+), 304 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 1bd8305fb..e81fb4723 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging from collections import deque from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F # noqa: N812 @@ -53,12 +53,13 @@ class VLAJEPAModel(nn.Module): - DiT-B: flow-matching action head for future action prediction - V-JEPA: world model for video frame prediction - Input: List[dict] native format (same as original starVLA) - - "image": List[Tensor [C, H, W]] (float [0,1], multi-view images) - - "video": Tensor [V, T, C, H, W] (float [0,1]) - - "lang": str (task instruction) - - "action": Tensor [T, action_dim] (optional, training only) - - "state": Tensor [1, state_dim] (optional) + Inputs are batched tensors kept on the model device + - images: List[List[Tensor [C, H, W]]] (float [0,1]) — per sample, per view (Qwen messages) + - instructions: List[str] + - videos: Tensor [B, V, T, C, H, W] (float [0,1], world model only) + - actions: Tensor [B, T, action_dim] (optional, training only) + - state: Tensor [B, 1, state_dim] (optional) + - action_is_pad: Tensor [B, T] (optional) """ def __init__(self, config: VLAJEPAConfig) -> None: @@ -159,166 +160,125 @@ class VLAJEPAModel(nn.Module): # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ---- - def forward(self, examples: list[dict]) -> dict[str, Tensor]: - """ - Native forward pass following original starVLA VLA_JEPA.forward. - - Args: - examples: List of per-sample dicts with keys: - "image" : List[Tensor [C, H, W]] (float [0,1]) — multi-view images - "video" : Tensor [V, T, C, H, W] (float [0,1]) - "lang" : str — task instruction - "action" : Tensor [T, action_dim] (optional) - "state" : Tensor [1, state_dim] (optional) - - Returns: - dict with "action_loss" and "wm_loss" keys (scalar Tensors). - """ - # Unpack native format (same pattern as original VLA_JEPA.py) - batch_images = [ex["image"] for ex in examples] # List[List[Tensor [C, H, W]]] - batch_videos = [ex["video"] for ex in examples] # List[Tensor [V, T, C, H, W]] - instructions = [ex["lang"] for ex in examples] # List[str] - has_action = "action" in examples[0] and examples[0]["action"] is not None - actions = [ex["action"] for ex in examples] if has_action else None - has_state = "state" in examples[0] and examples[0]["state"] is not None - state = [ex["state"] for ex in examples] if has_state else None - action_is_pad = ( - [ex["action_is_pad"] for ex in examples] - if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None - else None - ) - - # Stack videos: List[[V, T, 3, H, W]] -> [B, V, T, 3, H, W] (already channels-first) - batch_videos = torch.stack(batch_videos) - - # Adjust number of views for the world model: - # - fewer views than expected: duplicate the first view to fill up - # - more views than expected: keep only the first num_views_world_model views - num_views_world_model = self.config.jepa_tubelet_size - if batch_videos.shape[1] < num_views_world_model: - num_missing_views = num_views_world_model - batch_videos.shape[1] - first_view = batch_videos[:, :1].repeat(1, num_missing_views, 1, 1, 1, 1) - batch_videos = torch.cat([batch_videos, first_view], dim=1) - elif batch_videos.shape[1] > num_views_world_model: - batch_videos = batch_videos[:, :num_views_world_model] - - # ---- Step 1: QwenVL encode (same as original) ---- + def _encode_qwen( + self, images: list[list[Tensor]], instructions: list[str], *, need_action_tokens: bool + ) -> tuple[Tensor, Tensor, Tensor | None]: + """Run Qwen and gather the embodied-action (and optionally action) token hidden states.""" qwen_inputs = self.qwen.build_inputs( - images=batch_images, + images=images, instructions=instructions, action_prompt=self.replace_prompt, embodied_prompt=self.embodied_replace_prompt, ) - - # Locate embodied-action tokens (always needed for action head) - embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id - embodied_indices = embodied_mask.nonzero(as_tuple=True) - - # Locate action tokens (only needed for world model predictor) - if self.config.enable_world_model: + input_ids = qwen_inputs["input_ids"] + embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True) + action_idx = None + if need_action_tokens: action_mask = torch.isin( - qwen_inputs["input_ids"], - torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device), + input_ids, torch.tensor(self.action_token_ids, device=input_ids.device) ) - action_indices = action_mask.nonzero(as_tuple=True) + action_idx = action_mask.nonzero(as_tuple=True) device_type = next(self.parameters()).device.type - with torch.autocast(device_type=device_type, dtype=torch.bfloat16): last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H] b, _, h = last_hidden.shape + embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h) + action_tokens = ( + last_hidden[action_idx[0], action_idx[1], :].view(b, -1, h) + if action_idx is not None + else None + ) + return last_hidden, embodied_action_tokens, action_tokens - if self.config.enable_world_model: - action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h) + def _world_model_loss(self, videos: Tensor, action_tokens: Tensor) -> Tensor: + """JEPA encode + predictor L1 loss. `videos` is [B, V, T, C, H, W] float in [0, 1].""" + # Match the world model's expected view count: pad with the first view, or trim extras. + num_views = self.config.jepa_tubelet_size + if videos.shape[1] < num_views: + missing = num_views - videos.shape[1] + videos = torch.cat([videos, videos[:, :1].repeat(1, missing, 1, 1, 1, 1)], dim=1) + elif videos.shape[1] > num_views: + videos = videos[:, :num_views] - embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) + b, v, t_frames, c, h_img, w_img = videos.shape + flat = videos.reshape(b * v, t_frames, c, h_img, w_img) + # Fast (torchvision) video processor on-device, do_rescale=False (frames already in [0, 1]). + video_pixels = self.video_processor( + videos=list(flat), + return_tensors="pt", + device=self.video_encoder.device, + do_rescale=False, + )["pixel_values_videos"] # [B*V, T, C, H, W] - # ---- Step 2+3: JEPA Encoder + Predictor ---- - device_wm = last_hidden.device - if not self.config.enable_world_model: - wm_loss = torch.tensor(0.0, device=device_wm) + with torch.no_grad(): + video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) + # Merge views: [B*V, ...] -> [B, ..., V*embed_dim] + video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) + + tubelet_size = self.video_encoder.config.tubelet_size + # num_video_frames raw frames → t_enc_total temporal positions after tubelet compression + t_enc_total = self.config.num_video_frames // tubelet_size + if t_enc_total < 2: + return torch.zeros((), device=video_embeddings.device) + + # Shift-by-one JEPA split: input_states = positions 0..T-2, gt_states = positions 1..T-1 + t_enc_ctx = t_enc_total - 1 + tokens_per_frame = video_embeddings.shape[1] // t_enc_total + input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :] + gt_states = video_embeddings[:, tokens_per_frame:, :] + + expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep + if action_tokens.shape[1] < expected_actions: + pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) + action_tokens = torch.cat([action_tokens, pad], dim=1) + + predicted_states = self.video_predictor( + input_states.float(), action_tokens[:, :expected_actions].float() + ) + return F.l1_loss(predicted_states, gt_states.float(), reduction="mean") + + def _action_loss( + self, + embodied_action_tokens: Tensor, + actions: Tensor, + state: Tensor | None, + action_is_pad: Tensor | None, + ) -> Tensor: + """Flow-matching action-head loss, repeated over `repeated_diffusion_steps`.""" + device_type = next(self.parameters()).device.type + with torch.autocast(device_type=device_type, dtype=torch.float32): + r = self.config.repeated_diffusion_steps + horizon = self.config.chunk_size + actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1) + embodied = embodied_action_tokens.repeat(r, 1, 1) + state_rep = state.to(embodied_action_tokens.dtype).repeat(r, 1, 1) if state is not None else None + pad_rep = action_is_pad[:, -horizon:].repeat(r, 1) if action_is_pad is not None else None + return self.action_model(embodied, actions_target, state_rep, pad_rep) + + def forward( + self, + images: list[list[Tensor]], + instructions: list[str], + videos: Tensor | None = None, + actions: Tensor | None = None, + state: Tensor | None = None, + action_is_pad: Tensor | None = None, + ) -> dict[str, Tensor]: + """Native forward: Qwen encode → optional world-model loss → optional action-head loss.""" + last_hidden, embodied_action_tokens, action_tokens = self._encode_qwen( + images, instructions, need_action_tokens=self.config.enable_world_model + ) + + if self.config.enable_world_model: + wm_loss = self._world_model_loss(videos, action_tokens) else: - b, v, t_frames, c, h_img, w_img = batch_videos.shape - batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img) + wm_loss = torch.zeros((), device=last_hidden.device) - # Fast (torchvision) video processor: pass GPU float [0,1] tensors + device so the - # resize/normalize stays on-device (no GPU->CPU->GPU roundtrip). do_rescale=False - # because the frames already arrive in [0, 1]. - video_pixels = self.video_processor( - videos=list(batch_videos_flat), - return_tensors="pt", - device=self.video_encoder.device, - do_rescale=False, - )["pixel_values_videos"] # [B*V, T, C, H, W] - - with torch.no_grad(): - video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) - # Merge views: [B*V, ...] -> [B, ..., V*embed_dim] - video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) - - tubelet_size = self.video_encoder.config.tubelet_size - device_wm = video_embeddings.device - # num_video_frames raw frames → t_enc_total temporal positions after tubelet compression - t_enc_total = self.config.num_video_frames // tubelet_size - - if t_enc_total < 2: - wm_loss = torch.tensor(0.0, device=device_wm) - else: - # Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232): - # input_states: positions 0..T-2, gt_states: positions 1..T-1 - t_enc_ctx = t_enc_total - 1 - tokens_per_frame = video_embeddings.shape[1] // t_enc_total - - input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :] - gt_states = video_embeddings[:, tokens_per_frame:, :] - - expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep - if action_tokens.shape[1] < expected_actions: - pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) - action_tokens = torch.cat([action_tokens, pad], dim=1) - - predicted_states = self.video_predictor( - input_states.float(), - action_tokens[:, :expected_actions].float(), - ) - - wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") - - if not has_action: + if actions is None: return {"wm_loss": wm_loss} - # ---- Step 4: Action Head ---- - with torch.autocast(device_type=device_type, dtype=torch.float32): - actions_tensor = torch.stack(actions).to( - device=last_hidden.device, dtype=torch.float32 - ) # [B, T_full, action_dim] - action_horizon = self.config.chunk_size - actions_target = actions_tensor[:, -action_horizon:, :] - - state_tensor = None - if state is not None: - state_tensor = torch.stack(state).to( - device=last_hidden.device, dtype=last_hidden.dtype - ) # [B, 1, state_dim] - - repeated_diffusion_steps = self.config.repeated_diffusion_steps - actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1) - embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1) - if state_tensor is not None: - state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1) - - action_is_pad_rep = None - if action_is_pad is not None: - pad_tensor = torch.stack( - [p.to(actions_target.device) for p in action_is_pad] - ) # [B, T_full] - pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon] - action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon] - - action_loss = self.action_model( - embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep - ) - + action_loss = self._action_loss(embodied_action_tokens, actions, state, action_is_pad) return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight} # ---- Native predict_action (follows original VLA_JEPA.predict_action) ---- @@ -326,60 +286,24 @@ class VLAJEPAModel(nn.Module): @torch.no_grad() def predict_action( self, - batch_images: list[list[Tensor]], + images: list[list[Tensor]], instructions: list[str], state: Tensor | None = None, ) -> Tensor: - """ - Native action prediction following original VLA_JEPA.predict_action. - - Args: - batch_images: List of samples; each is List[Tensor [C, H, W]] (float [0,1], multi-view). - instructions: Task instructions, one per sample. - state: Optional [B, state_dim] tensor. - - Returns: - Tensor [B, action_horizon, action_dim] — predicted actions (on the model device). - """ + """Predict an action chunk. `images` is per-sample, per-view float [0,1] [C, H, W] tensors.""" if self.config.resize_images_to is not None: height, width = self.config.resize_images_to - # PIL BOX resampling ~= area-averaging downsample; F.interpolate(mode="area") - # is the on-GPU equivalent. Images stay float [0,1] (do_rescale=False downstream). - batch_images = [ - [ - F.interpolate(image[None], size=(height, width), mode="area")[0] - for image in sample_images - ] - for sample_images in batch_images + images = [ + [F.interpolate(img[None], size=(height, width), mode="area")[0] for img in views] + for views in images ] - qwen_inputs = self.qwen.build_inputs( - images=batch_images, - instructions=instructions, - action_prompt=self.replace_prompt, - embodied_prompt=self.embodied_replace_prompt, + _, embodied_action_tokens, _ = self._encode_qwen(images, instructions, need_action_tokens=False) + state = state.to(embodied_action_tokens.dtype) if state is not None else None + return self.action_model.predict_action( + embodied_action_tokens.float(), state.float() if state is not None else None ) - embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id - embodied_indices = embodied_mask.nonzero(as_tuple=True) - - device_type = next(self.parameters()).device.type - - with torch.autocast(device_type=device_type, dtype=torch.bfloat16): - last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H] - b, _, h = last_hidden.shape - embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) - - state_tensor = None - if state is not None: - state_tensor = state.to(device=last_hidden.device, dtype=last_hidden.dtype) - - pred_actions = self.action_model.predict_action( - embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None - ) # [B, action_horizon, action_dim] - - return pred_actions - # ============================================================================ # LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format @@ -390,9 +314,9 @@ class VLAJEPAPolicy(PreTrainedPolicy): """ LeRobot adapter for VLA-JEPA. - Converts LeRobot's standard batch format (dict[str, Tensor]) to the native - VLA-JEPA format (List[dict]), calls the native model, and converts outputs - back to LeRobot format. + Converts LeRobot's standard batch format (dict[str, Tensor]) to the batched tensors + the native model expects (keeping everything on-device), calls the native model, and + converts outputs back to LeRobot format. """ config_class = VLAJEPAConfig @@ -419,9 +343,8 @@ class VLAJEPAPolicy(PreTrainedPolicy): # ---- Format Conversion: LeRobot → Native ---- - def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]: - """ - Convert LeRobot batch format to native VLA-JEPA examples format. + def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Any]: + """Convert a LeRobot batch to the model's batched, on-device inputs. LeRobot format: batch = { @@ -431,49 +354,25 @@ class VLAJEPAPolicy(PreTrainedPolicy): "task": str | List[str], (optional instruction) } - Native format (List[dict]), all tensors kept on the batch device: - { - "image": List[Tensor [C, H, W]] (float [0,1]), # multi-view images per sample - "video": Tensor [V, T, C, H, W] (float [0,1]), - "lang": str, # task instruction - "action": Tensor [T, action_dim], # optional - "state": Tensor [1, state_dim], # optional - } + Returns the kwargs for `VLAJEPAModel.forward` / `.predict_action` (everything stays + on the batch device; no per-sample shredding): `images` (per-sample, per-view list for + Qwen messages), `instructions`, and the batched `videos` / `actions` / `state` / + `action_is_pad` when present. """ - # Determine batch size from the first image feature image_keys = list(self.config.image_features.keys()) if not image_keys: raise ValueError("VLAJEPA requires at least one image feature.") - first_key = image_keys[0] - first_tensor = batch[first_key] - batch_size = first_tensor.shape[0] + batch_size = batch[image_keys[0]].shape[0] - # ---- Collect images per sample ---- - # images_per_sample[b][v] = float [0,1] Tensor [C, H, W] for view v (kept on-device) - images_per_sample: list[list[Tensor]] = [[] for _ in range(batch_size)] + # Current-frame image per view ([B, C, H, W]); regroup per sample for Qwen messages. + frames = [] for key in image_keys: - tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W] - if tensor.ndim == 5: - # observation_delta_indices = [0, 1, ..., num_video_frames-1] - # index 0 is the current observation (delta=0) - tensor = tensor[:, 0] - for b in range(batch_size): - images_per_sample[b].append(self.model.qwen.to_pixel_values(tensor[b])) + t = batch[key] + if t.ndim == 5: # [B, T, C, H, W] -> current observation (delta=0) + t = t[:, 0] + frames.append(self.model.qwen.to_pixel_values(t)) + images = [[frame[b] for frame in frames] for b in range(batch_size)] - # ---- Collect videos per sample ---- - # Build video tensors: for each sample, stack views as [V, T, C, H, W] (float [0,1], on-device) - videos_per_sample = [] - for b in range(batch_size): - sample_views = [] - for k in image_keys: - t = batch[k][b] # [C, H, W] or [T, C, H, W] - if t.ndim == 3: - t = t.unsqueeze(0) # [1, C, H, W] - sample_views.append(self.model.qwen.to_pixel_values(t)) - # Stack views: [V, T, C, H, W] - videos_per_sample.append(torch.stack(sample_views, dim=0)) - - # ---- Collect instructions ---- tasks = batch.get("task") if tasks is None: instructions = ["Execute the robot action."] * batch_size @@ -482,52 +381,32 @@ class VLAJEPAPolicy(PreTrainedPolicy): else: instructions = list(tasks) - # ---- Collect actions (training only) ---- - actions_list = None - action_is_pad_list = None - actions_tensor = batch.get(ACTION) - if actions_tensor is not None: - if actions_tensor.ndim == 2: - actions_tensor = actions_tensor.unsqueeze(1) - actions_list = [actions_tensor[b].detach().float() for b in range(batch_size)] - action_is_pad_tensor = batch.get("action_is_pad") - if action_is_pad_tensor is not None: - action_is_pad_list = [action_is_pad_tensor[b].detach() for b in range(batch_size)] + inputs: dict[str, Any] = {"images": images, "instructions": instructions} - # ---- Collect state ---- - state_list = None - state_tensor = batch.get(OBS_STATE) - if state_tensor is not None: - if state_tensor.ndim > 2: - state_tensor = state_tensor[:, -1, :] - if state_tensor.ndim == 2: - state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim] - state_list = [state_tensor[b].detach().float() for b in range(batch_size)] + # Videos [B, V, T, C, H, W] - only assembled when the world model consumes them. + if self.model.config.enable_world_model: + views = [batch[k].unsqueeze(1) if batch[k].ndim == 4 else batch[k] for k in image_keys] + inputs["videos"] = self.model.qwen.to_pixel_values(torch.stack(views, dim=1)) - # ---- Assemble native examples ---- - examples = [] - for b in range(batch_size): - example = { - "image": images_per_sample[b], - "video": videos_per_sample[b], - "lang": instructions[b], - } - if actions_list is not None: - example["action"] = actions_list[b] - if action_is_pad_list is not None: - example["action_is_pad"] = action_is_pad_list[b] - if state_list is not None: - example["state"] = state_list[b] - examples.append(example) + actions = batch.get(ACTION) + if actions is not None: + inputs["actions"] = (actions.unsqueeze(1) if actions.ndim == 2 else actions).float() + if (pad := batch.get("action_is_pad")) is not None: + inputs["action_is_pad"] = pad - return examples + state = batch.get(OBS_STATE) + if state is not None: + if state.ndim > 2: + state = state[:, -1, :] + inputs["state"] = (state.unsqueeze(1) if state.ndim == 2 else state).float() # [B, 1, dim] + + return inputs # ---- LeRobot Policy Interface ---- def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """LeRobot train forward: convert → native forward → aggregate losses.""" - examples = self._prepare_model_inputs(batch) - native_output = self.model.forward(examples) + native_output = self.model.forward(**self._prepare_model_inputs(batch)) ref = next(iter(native_output.values())) zero = torch.zeros((), device=ref.device, dtype=ref.dtype) @@ -545,15 +424,8 @@ class VLAJEPAPolicy(PreTrainedPolicy): self.eval() self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - examples = self._prepare_model_inputs(batch) - batch_images = [ex["image"] for ex in examples] - instructions = [ex["lang"] for ex in examples] - - state = None - if "state" in examples[0] and examples[0]["state"] is not None: - state = torch.stack([ex["state"] for ex in examples]) - - actions = self.model.predict_action(batch_images, instructions, state) + inputs = self._prepare_model_inputs(batch) + actions = self.model.predict_action(inputs["images"], inputs["instructions"], inputs.get("state")) return actions.to(device=self.config.device, dtype=torch.float32) @torch.no_grad() diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py index d24e6aaaa..bcad1f558 100644 --- a/src/lerobot/policies/vla_jepa/qwen_interface.py +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -121,8 +121,13 @@ class Qwen3VLInterface(torch.nn.Module): normalization is IDENTITY, so the tensor already arrives in [0, 1]; we pass it through as float and let the processors normalize (no rescale, no uint8 quantization). A single channel is expanded to 3 to match the RGB processors. + + Works for any channels-first layout (channel dim is -3): [C, H, W], [B, C, H, W], + [T, C, H, W], [B, V, T, C, H, W], ... """ image = image_tensor.detach().float() - if image.ndim == 3 and image.shape[0] == 1: - image = image.repeat(3, 1, 1) + if image.shape[-3] == 1: + repeats = [1] * image.ndim + repeats[-3] = 3 + image = image.repeat(*repeats) return image diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py index 799802d5b..dd40ca9ea 100644 --- a/tests/policies/vla_jepa/conftest.py +++ b/tests/policies/vla_jepa/conftest.py @@ -215,8 +215,10 @@ class _FakeQwenInterface(nn.Module): @staticmethod def to_pixel_values(image_tensor: Tensor) -> Tensor: image = image_tensor.detach().float() - if image.ndim == 3 and image.shape[0] == 1: - image = image.repeat(3, 1, 1) + if image.shape[-3] == 1: + repeats = [1] * image.ndim + repeats[-3] = 3 + image = image.repeat(*repeats) return image diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index b9bc398a2..a3e24a660 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -212,40 +212,41 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) - examples = policy._prepare_model_inputs(make_train_batch()) + inputs = policy._prepare_model_inputs(make_train_batch()) - assert len(examples) == BATCH_SIZE - for ex in examples: - assert set(ex) >= {"image", "video", "lang", "action", "state"} - assert len(ex["image"]) == 1 - assert isinstance(ex["image"][0], torch.Tensor) and ex["image"][0].dtype == torch.float32 - assert ex["image"][0].ndim == 3 # [C, H, W] - assert isinstance(ex["video"], torch.Tensor) - assert ex["video"].ndim == 5 and ex["video"].dtype == torch.float32 # [V, T, C, H, W] - assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM) - assert ex["state"].shape == (1, STATE_DIM) + assert set(inputs) >= {"images", "instructions", "videos", "actions", "state"} + # images: per-sample, per-view [C, H, W] float tensors (kept as a list for Qwen messages) + assert len(inputs["images"]) == BATCH_SIZE and len(inputs["images"][0]) == 1 + img = inputs["images"][0][0] + assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 and img.ndim == 3 + assert len(inputs["instructions"]) == BATCH_SIZE + # videos: batched [B, V, T, C, H, W] float + assert inputs["videos"].ndim == 6 and inputs["videos"].shape[0] == BATCH_SIZE + assert inputs["videos"].dtype == torch.float32 + assert inputs["actions"].shape == (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + assert inputs["state"].shape == (BATCH_SIZE, 1, STATE_DIM) def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) - for ex in policy._prepare_model_inputs(make_inference_batch()): - assert "action" not in ex - assert "image" in ex and "video" in ex and "lang" in ex + inputs = policy._prepare_model_inputs(make_inference_batch()) + assert "actions" not in inputs and "action_is_pad" not in inputs + assert {"images", "instructions", "state"} <= set(inputs) def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) batch = make_inference_batch() del batch["task"] - examples = policy._prepare_model_inputs(batch) - assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples) + instructions = policy._prepare_model_inputs(batch)["instructions"] + assert all(isinstance(s, str) and len(s) > 0 for s in instructions) def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) batch = make_inference_batch() batch["task"] = "open the drawer" - assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch)) + assert policy._prepare_model_inputs(batch)["instructions"] == ["open the drawer"] * BATCH_SIZE def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None: @@ -254,7 +255,7 @@ def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: N policy = VLAJEPAPolicy(make_config()) batch = make_inference_batch() del batch[OBS_STATE] - assert all("state" not in ex for ex in policy._prepare_model_inputs(batch)) + assert "state" not in policy._prepare_model_inputs(batch) # ---------------------------------------------------------------------------