major refactor of the forward pass and model input conversion

This commit is contained in:
Maxime Ellerbach
2026-06-09 12:14:43 +00:00
committed by Maximellerbach
parent 877847c90e
commit 31ddb8f493
4 changed files with 184 additions and 304 deletions
+154 -282
View File
@@ -17,7 +17,7 @@ from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
@@ -53,12 +53,13 @@ class VLAJEPAModel(nn.Module):
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[Tensor [C, H, W]] (float [0,1], multi-view images)
- "video": Tensor [V, T, C, H, W] (float [0,1])
- "lang": str (task instruction)
- "action": Tensor [T, action_dim] (optional, training only)
- "state": Tensor [1, state_dim] (optional)
Inputs are batched tensors kept on the model device
- images: List[List[Tensor [C, H, W]]] (float [0,1]) — per sample, per view (Qwen messages)
- instructions: List[str]
- videos: Tensor [B, V, T, C, H, W] (float [0,1], world model only)
- actions: Tensor [B, T, action_dim] (optional, training only)
- state: Tensor [B, 1, state_dim] (optional)
- action_is_pad: Tensor [B, T] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
@@ -159,166 +160,125 @@ class VLAJEPAModel(nn.Module):
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[Tensor [C, H, W]] (float [0,1]) — multi-view images
"video" : Tensor [V, T, C, H, W] (float [0,1])
"lang" : str — task instruction
"action" : Tensor [T, action_dim] (optional)
"state" : Tensor [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[Tensor [C, H, W]]]
batch_videos = [ex["video"] for ex in examples] # List[Tensor [V, T, C, H, W]]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: List[[V, T, 3, H, W]] -> [B, V, T, 3, H, W] (already channels-first)
batch_videos = torch.stack(batch_videos)
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = batch_videos[:, :1].repeat(1, num_missing_views, 1, 1, 1, 1)
batch_videos = torch.cat([batch_videos, first_view], dim=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
def _encode_qwen(
self, images: list[list[Tensor]], instructions: list[str], *, need_action_tokens: bool
) -> tuple[Tensor, Tensor, Tensor | None]:
"""Run Qwen and gather the embodied-action (and optionally action) token hidden states."""
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
images=images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
input_ids = qwen_inputs["input_ids"]
embodied_idx = (input_ids == self.embodied_action_token_id).nonzero(as_tuple=True)
action_idx = None
if need_action_tokens:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
input_ids, torch.tensor(self.action_token_ids, device=input_ids.device)
)
action_indices = action_mask.nonzero(as_tuple=True)
action_idx = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
action_tokens = (
last_hidden[action_idx[0], action_idx[1], :].view(b, -1, h)
if action_idx is not None
else None
)
return last_hidden, embodied_action_tokens, action_tokens
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
def _world_model_loss(self, videos: Tensor, action_tokens: Tensor) -> Tensor:
"""JEPA encode + predictor L1 loss. `videos` is [B, V, T, C, H, W] float in [0, 1]."""
# Match the world model's expected view count: pad with the first view, or trim extras.
num_views = self.config.jepa_tubelet_size
if videos.shape[1] < num_views:
missing = num_views - videos.shape[1]
videos = torch.cat([videos, videos[:, :1].repeat(1, missing, 1, 1, 1, 1)], dim=1)
elif videos.shape[1] > num_views:
videos = videos[:, :num_views]
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
b, v, t_frames, c, h_img, w_img = videos.shape
flat = videos.reshape(b * v, t_frames, c, h_img, w_img)
# Fast (torchvision) video processor on-device, do_rescale=False (frames already in [0, 1]).
video_pixels = self.video_processor(
videos=list(flat),
return_tensors="pt",
device=self.video_encoder.device,
do_rescale=False,
)["pixel_values_videos"] # [B*V, T, C, H, W]
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
return torch.zeros((), device=video_embeddings.device)
# Shift-by-one JEPA split: input_states = positions 0..T-2, gt_states = positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(), action_tokens[:, :expected_actions].float()
)
return F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
def _action_loss(
self,
embodied_action_tokens: Tensor,
actions: Tensor,
state: Tensor | None,
action_is_pad: Tensor | None,
) -> Tensor:
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.float32):
r = self.config.repeated_diffusion_steps
horizon = self.config.chunk_size
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
embodied = embodied_action_tokens.repeat(r, 1, 1)
state_rep = state.to(embodied_action_tokens.dtype).repeat(r, 1, 1) if state is not None else None
pad_rep = action_is_pad[:, -horizon:].repeat(r, 1) if action_is_pad is not None else None
return self.action_model(embodied, actions_target, state_rep, pad_rep)
def forward(
self,
images: list[list[Tensor]],
instructions: list[str],
videos: Tensor | None = None,
actions: Tensor | None = None,
state: Tensor | None = None,
action_is_pad: Tensor | None = None,
) -> dict[str, Tensor]:
"""Native forward: Qwen encode → optional world-model loss → optional action-head loss."""
last_hidden, embodied_action_tokens, action_tokens = self._encode_qwen(
images, instructions, need_action_tokens=self.config.enable_world_model
)
if self.config.enable_world_model:
wm_loss = self._world_model_loss(videos, action_tokens)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
wm_loss = torch.zeros((), device=last_hidden.device)
# Fast (torchvision) video processor: pass GPU float [0,1] tensors + device so the
# resize/normalize stays on-device (no GPU->CPU->GPU roundtrip). do_rescale=False
# because the frames already arrive in [0, 1].
video_pixels = self.video_processor(
videos=list(batch_videos_flat),
return_tensors="pt",
device=self.video_encoder.device,
do_rescale=False,
)["pixel_values_videos"] # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
if actions is None:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.stack(actions).to(
device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.stack(state).to(
device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[p.to(actions_target.device) for p in action_is_pad]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
action_loss = self._action_loss(embodied_action_tokens, actions, state, action_is_pad)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@@ -326,60 +286,24 @@ class VLAJEPAModel(nn.Module):
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Tensor]],
images: list[list[Tensor]],
instructions: list[str],
state: Tensor | None = None,
) -> Tensor:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[Tensor [C, H, W]] (float [0,1], multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] tensor.
Returns:
Tensor [B, action_horizon, action_dim] — predicted actions (on the model device).
"""
"""Predict an action chunk. `images` is per-sample, per-view float [0,1] [C, H, W] tensors."""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
# PIL BOX resampling ~= area-averaging downsample; F.interpolate(mode="area")
# is the on-GPU equivalent. Images stay float [0,1] (do_rescale=False downstream).
batch_images = [
[
F.interpolate(image[None], size=(height, width), mode="area")[0]
for image in sample_images
]
for sample_images in batch_images
images = [
[F.interpolate(img[None], size=(height, width), mode="area")[0] for img in views]
for views in images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
_, embodied_action_tokens, _ = self._encode_qwen(images, instructions, need_action_tokens=False)
state = state.to(embodied_action_tokens.dtype) if state is not None else None
return self.action_model.predict_action(
embodied_action_tokens.float(), state.float() if state is not None else None
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = state.to(device=last_hidden.device, dtype=last_hidden.dtype)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
@@ -390,9 +314,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the batched tensors
the native model expects (keeping everything on-device), calls the native model, and
converts outputs back to LeRobot format.
"""
config_class = VLAJEPAConfig
@@ -419,9 +343,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Any]:
"""Convert a LeRobot batch to the model's batched, on-device inputs.
LeRobot format:
batch = {
@@ -431,49 +354,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
"task": str | List[str], (optional instruction)
}
Native format (List[dict]), all tensors kept on the batch device:
{
"image": List[Tensor [C, H, W]] (float [0,1]), # multi-view images per sample
"video": Tensor [V, T, C, H, W] (float [0,1]),
"lang": str, # task instruction
"action": Tensor [T, action_dim], # optional
"state": Tensor [1, state_dim], # optional
}
Returns the kwargs for `VLAJEPAModel.forward` / `.predict_action` (everything stays
on the batch device; no per-sample shredding): `images` (per-sample, per-view list for
Qwen messages), `instructions`, and the batched `videos` / `actions` / `state` /
`action_is_pad` when present.
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
batch_size = batch[image_keys[0]].shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = float [0,1] Tensor [C, H, W] for view v (kept on-device)
images_per_sample: list[list[Tensor]] = [[] for _ in range(batch_size)]
# Current-frame image per view ([B, C, H, W]); regroup per sample for Qwen messages.
frames = []
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.to_pixel_values(tensor[b]))
t = batch[key]
if t.ndim == 5: # [B, T, C, H, W] -> current observation (delta=0)
t = t[:, 0]
frames.append(self.model.qwen.to_pixel_values(t))
images = [[frame[b] for frame in frames] for b in range(batch_size)]
# ---- Collect videos per sample ----
# Build video tensors: for each sample, stack views as [V, T, C, H, W] (float [0,1], on-device)
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
sample_views.append(self.model.qwen.to_pixel_values(t))
# Stack views: [V, T, C, H, W]
videos_per_sample.append(torch.stack(sample_views, dim=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
@@ -482,52 +381,32 @@ class VLAJEPAPolicy(PreTrainedPolicy):
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().float() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach() for b in range(batch_size)]
inputs: dict[str, Any] = {"images": images, "instructions": instructions}
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().float() for b in range(batch_size)]
# Videos [B, V, T, C, H, W] - only assembled when the world model consumes them.
if self.model.config.enable_world_model:
views = [batch[k].unsqueeze(1) if batch[k].ndim == 4 else batch[k] for k in image_keys]
inputs["videos"] = self.model.qwen.to_pixel_values(torch.stack(views, dim=1))
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
actions = batch.get(ACTION)
if actions is not None:
inputs["actions"] = (actions.unsqueeze(1) if actions.ndim == 2 else actions).float()
if (pad := batch.get("action_is_pad")) is not None:
inputs["action_is_pad"] = pad
return examples
state = batch.get(OBS_STATE)
if state is not None:
if state.ndim > 2:
state = state[:, -1, :]
inputs["state"] = (state.unsqueeze(1) if state.ndim == 2 else state).float() # [B, 1, dim]
return inputs
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
native_output = self.model.forward(**self._prepare_model_inputs(batch))
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
@@ -545,15 +424,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state = None
if "state" in examples[0] and examples[0]["state"] is not None:
state = torch.stack([ex["state"] for ex in examples])
actions = self.model.predict_action(batch_images, instructions, state)
inputs = self._prepare_model_inputs(batch)
actions = self.model.predict_action(inputs["images"], inputs["instructions"], inputs.get("state"))
return actions.to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
@@ -121,8 +121,13 @@ class Qwen3VLInterface(torch.nn.Module):
normalization is IDENTITY, so the tensor already arrives in [0, 1]; we pass it
through as float and let the processors normalize (no rescale, no uint8
quantization). A single channel is expanded to 3 to match the RGB processors.
Works for any channels-first layout (channel dim is -3): [C, H, W], [B, C, H, W],
[T, C, H, W], [B, V, T, C, H, W], ...
"""
image = image_tensor.detach().float()
if image.ndim == 3 and image.shape[0] == 1:
image = image.repeat(3, 1, 1)
if image.shape[-3] == 1:
repeats = [1] * image.ndim
repeats[-3] = 3
image = image.repeat(*repeats)
return image
+4 -2
View File
@@ -215,8 +215,10 @@ class _FakeQwenInterface(nn.Module):
@staticmethod
def to_pixel_values(image_tensor: Tensor) -> Tensor:
image = image_tensor.detach().float()
if image.ndim == 3 and image.shape[0] == 1:
image = image.repeat(3, 1, 1)
if image.shape[-3] == 1:
repeats = [1] * image.ndim
repeats[-3] = 3
image = image.repeat(*repeats)
return image
+19 -18
View File
@@ -212,40 +212,41 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
examples = policy._prepare_model_inputs(make_train_batch())
inputs = policy._prepare_model_inputs(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
assert set(ex) >= {"image", "video", "lang", "action", "state"}
assert len(ex["image"]) == 1
assert isinstance(ex["image"][0], torch.Tensor) and ex["image"][0].dtype == torch.float32
assert ex["image"][0].ndim == 3 # [C, H, W]
assert isinstance(ex["video"], torch.Tensor)
assert ex["video"].ndim == 5 and ex["video"].dtype == torch.float32 # [V, T, C, H, W]
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
assert ex["state"].shape == (1, STATE_DIM)
assert set(inputs) >= {"images", "instructions", "videos", "actions", "state"}
# images: per-sample, per-view [C, H, W] float tensors (kept as a list for Qwen messages)
assert len(inputs["images"]) == BATCH_SIZE and len(inputs["images"][0]) == 1
img = inputs["images"][0][0]
assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 and img.ndim == 3
assert len(inputs["instructions"]) == BATCH_SIZE
# videos: batched [B, V, T, C, H, W] float
assert inputs["videos"].ndim == 6 and inputs["videos"].shape[0] == BATCH_SIZE
assert inputs["videos"].dtype == torch.float32
assert inputs["actions"].shape == (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
assert inputs["state"].shape == (BATCH_SIZE, 1, STATE_DIM)
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
for ex in policy._prepare_model_inputs(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
inputs = policy._prepare_model_inputs(make_inference_batch())
assert "actions" not in inputs and "action_is_pad" not in inputs
assert {"images", "instructions", "state"} <= set(inputs)
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
examples = policy._prepare_model_inputs(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
instructions = policy._prepare_model_inputs(batch)["instructions"]
assert all(isinstance(s, str) and len(s) > 0 for s in instructions)
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
assert policy._prepare_model_inputs(batch)["instructions"] == ["open the drawer"] * BATCH_SIZE
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
@@ -254,7 +255,7 @@ def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: N
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
assert "state" not in policy._prepare_model_inputs(batch)
# ---------------------------------------------------------------------------