This commit is contained in:
Maximellerbach
2026-05-11 16:49:18 +02:00
parent 3144029814
commit 2e9ba42e1b
4 changed files with 24 additions and 34 deletions
+1 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F # noqa: N812
from diffusers import ConfigMixin, ModelMixin from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward from diffusers.models.attention import Attention, FeedForward
@@ -5,7 +5,7 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F # noqa: N812
from PIL import Image from PIL import Image
from torch import Tensor, nn from torch import Tensor, nn
from transformers import AutoModel, AutoVideoProcessor from transformers import AutoModel, AutoVideoProcessor
@@ -139,18 +139,18 @@ class VLAJEPAModel(nn.Module):
return_dict=True, return_dict=True,
) )
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H] last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
B, _, H = last_hidden.shape b, _, h = last_hidden.shape
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(B, -1, H) action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H) embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2: JEPA Encoder (same as original) ---- # ---- Step 2: JEPA Encoder (same as original) ----
B, V, T_frames, C, H_img, W_img = batch_videos.shape b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(B * V, T_frames, C, H_img, W_img) batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = [] video_pixels = []
for i in range(B * V): for i in range(b * v):
video_pixels.append( video_pixels.append(
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[ self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
"pixel_values_videos" "pixel_values_videos"
@@ -161,41 +161,41 @@ class VLAJEPAModel(nn.Module):
with torch.no_grad(): with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim] # Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=V, dim=0), dim=2) video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
# ---- Step 3: JEPA Predictor (same as original) ---- # ---- Step 3: JEPA Predictor (same as original) ----
tubelet_size = self.video_encoder.config.tubelet_size tubelet_size = self.video_encoder.config.tubelet_size
T_enc = T_frames // tubelet_size t_enc = t_frames // tubelet_size
device_wm = video_embeddings.device device_wm = video_embeddings.device
if T_enc < 2: if t_enc < 2:
# Not enough frames for JEPA prediction (need at least 2 encoded frames) # Not enough frames for JEPA prediction (need at least 2 encoded frames)
wm_loss = torch.tensor(0.0, device=device_wm) wm_loss = torch.tensor(0.0, device=device_wm)
else: else:
tokens_per_frame = video_embeddings.shape[1] // T_enc tokens_per_frame = video_embeddings.shape[1] // t_enc
# input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D] # input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
# gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D] # gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :] input_states = video_embeddings[:, : tokens_per_frame * (t_enc - 1), :]
gt_states = video_embeddings[:, tokens_per_frame:, :] gt_states = video_embeddings[:, tokens_per_frame:, :]
D_emb = input_states.shape[-1] d_emb = input_states.shape[-1]
# Reshape to 4D for ActionConditionedVideoPredictor: # Reshape to 4D for ActionConditionedVideoPredictor:
# [B, (T-1)*tokens, D] → [B, T-1, tokens, D] # [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
input_states_4d = input_states.view(B, T_enc - 1, tokens_per_frame, D_emb) input_states_4d = input_states.view(b, t_enc - 1, tokens_per_frame, d_emb)
# Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D] # Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
expected_actions = (T_enc - 1) * self.config.num_action_tokens_per_timestep expected_actions = (t_enc - 1) * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions: if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1) action_tokens = torch.cat([action_tokens, pad], dim=1)
act_4d = action_tokens[:, :expected_actions].view( act_4d = action_tokens[:, :expected_actions].view(
B, T_enc - 1, self.config.num_action_tokens_per_timestep, -1 b, t_enc - 1, self.config.num_action_tokens_per_timestep, -1
) )
# Cast to float32 for predictor (Linear layers are float32) # Cast to float32 for predictor (Linear layers are float32)
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float()) pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
predicted_states = pred_4d.reshape(B, -1, D_emb) predicted_states = pred_4d.reshape(b, -1, d_emb)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
@@ -261,8 +261,8 @@ class VLAJEPAModel(nn.Module):
return_dict=True, return_dict=True,
) )
last_hidden = qwen_outputs.hidden_states[-1] last_hidden = qwen_outputs.hidden_states[-1]
B, _, H = last_hidden.shape b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H) embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None state_tensor = None
if state is not None: if state is not None:
@@ -349,8 +349,6 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# ---- Collect videos per sample ---- # ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3] # Build video arrays: for each sample, stack views as [V, T, H, W, 3]
num_views = len(image_keys)
has_video = any(batch[k].ndim == 5 for k in image_keys if k in batch)
# Check whether any image feature has a time dimension # Check whether any image feature has a time dimension
video_source = None video_source = None
for k in image_keys: for k in image_keys:
@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence from collections.abc import Sequence
import numpy as np import numpy as np
import torch import torch
+3 -11
View File
@@ -17,7 +17,6 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning" "ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
) )
@@ -81,10 +80,7 @@ class _FakeQwenInterface(nn.Module):
def expand_tokenizer(self) -> tuple[list[str], list[int], int]: def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [ action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
self.config.special_action_token.format(idx)
for idx in range(max_action_tokens)
]
action_token_ids = list(range(1000, 1000 + max_action_tokens)) action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000 return action_tokens, action_token_ids, 2000
@@ -226,9 +222,7 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
loss.backward() loss.backward()
assert any( assert any(
param.grad is not None param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad
for param in policy.model.action_model.parameters()
if param.requires_grad
) )
assert set(batch) == set(batch_before) assert set(batch) == set(batch_before)
for key, value in batch.items(): for key, value in batch.items():
@@ -299,9 +293,7 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
local_files_only=True, local_files_only=True,
) )
except LocalEntryNotFoundError: except LocalEntryNotFoundError:
pytest.skip( pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.")
f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache."
)
try: try:
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False) checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)