From 2e9ba42e1b9fd6ed902b9a271a0b5c10c2f2c0a6 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Mon, 11 May 2026 16:49:18 +0200 Subject: [PATCH] linting --- src/lerobot/policies/vla_jepa/action_head.py | 2 +- .../policies/vla_jepa/modeling_vla_jepa.py | 40 +++++++++---------- .../policies/vla_jepa/qwen_interface.py | 2 +- tests/policies/vla_jepa/test_vla_jepa.py | 14 ++----- 4 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index 8e3cc94a9..76d105194 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass import torch -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 from diffusers import ConfigMixin, ModelMixin from diffusers.configuration_utils import register_to_config from diffusers.models.attention import Attention, FeedForward diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 90b9b42b4..aa178be51 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -5,7 +5,7 @@ from pathlib import Path import numpy as np import torch -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 from PIL import Image from torch import Tensor, nn from transformers import AutoModel, AutoVideoProcessor @@ -139,18 +139,18 @@ class VLAJEPAModel(nn.Module): return_dict=True, ) last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H] - B, _, H = last_hidden.shape + b, _, h = last_hidden.shape - action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(B, -1, H) + action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h) - embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H) + embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) # ---- Step 2: JEPA Encoder (same as original) ---- - B, V, T_frames, C, H_img, W_img = batch_videos.shape - batch_videos_flat = batch_videos.reshape(B * V, T_frames, C, H_img, W_img) + b, v, t_frames, c, h_img, w_img = batch_videos.shape + batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img) video_pixels = [] - for i in range(B * V): + for i in range(b * v): video_pixels.append( self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[ "pixel_values_videos" @@ -161,41 +161,41 @@ class VLAJEPAModel(nn.Module): with torch.no_grad(): video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) # Merge views: [B*V, ...] -> [B, ..., V*embed_dim] - video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=V, dim=0), dim=2) + video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) # ---- Step 3: JEPA Predictor (same as original) ---- tubelet_size = self.video_encoder.config.tubelet_size - T_enc = T_frames // tubelet_size + t_enc = t_frames // tubelet_size device_wm = video_embeddings.device - if T_enc < 2: + if t_enc < 2: # Not enough frames for JEPA prediction (need at least 2 encoded frames) wm_loss = torch.tensor(0.0, device=device_wm) else: - tokens_per_frame = video_embeddings.shape[1] // T_enc + tokens_per_frame = video_embeddings.shape[1] // t_enc # input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D] # gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D] - input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :] + input_states = video_embeddings[:, : tokens_per_frame * (t_enc - 1), :] gt_states = video_embeddings[:, tokens_per_frame:, :] - D_emb = input_states.shape[-1] + d_emb = input_states.shape[-1] # Reshape to 4D for ActionConditionedVideoPredictor: # [B, (T-1)*tokens, D] → [B, T-1, tokens, D] - input_states_4d = input_states.view(B, T_enc - 1, tokens_per_frame, D_emb) + input_states_4d = input_states.view(b, t_enc - 1, tokens_per_frame, d_emb) # Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D] - expected_actions = (T_enc - 1) * self.config.num_action_tokens_per_timestep + expected_actions = (t_enc - 1) * self.config.num_action_tokens_per_timestep if action_tokens.shape[1] < expected_actions: pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) action_tokens = torch.cat([action_tokens, pad], dim=1) act_4d = action_tokens[:, :expected_actions].view( - B, T_enc - 1, self.config.num_action_tokens_per_timestep, -1 + b, t_enc - 1, self.config.num_action_tokens_per_timestep, -1 ) # Cast to float32 for predictor (Linear layers are float32) pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float()) - predicted_states = pred_4d.reshape(B, -1, D_emb) + predicted_states = pred_4d.reshape(b, -1, d_emb) wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") @@ -261,8 +261,8 @@ class VLAJEPAModel(nn.Module): return_dict=True, ) last_hidden = qwen_outputs.hidden_states[-1] - B, _, H = last_hidden.shape - embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H) + b, _, h = last_hidden.shape + embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) state_tensor = None if state is not None: @@ -349,8 +349,6 @@ class VLAJEPAPolicy(PreTrainedPolicy): # ---- Collect videos per sample ---- # Build video arrays: for each sample, stack views as [V, T, H, W, 3] - num_views = len(image_keys) - has_video = any(batch[k].ndim == 5 for k in image_keys if k in batch) # Check whether any image feature has a time dimension video_source = None for k in image_keys: diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py index 1e1e7a895..044b6f989 100644 --- a/src/lerobot/policies/vla_jepa/qwen_interface.py +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from collections.abc import Sequence import numpy as np import torch diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index 0a42013b6..48c8ab9b4 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -17,7 +17,6 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE - pytestmark = pytest.mark.filterwarnings( "ignore:In CPU autocast, but the target dtype is not supported:UserWarning" ) @@ -81,10 +80,7 @@ class _FakeQwenInterface(nn.Module): def expand_tokenizer(self) -> tuple[list[str], list[int], int]: max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep - action_tokens = [ - self.config.special_action_token.format(idx) - for idx in range(max_action_tokens) - ] + action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)] action_token_ids = list(range(1000, 1000 + max_action_tokens)) return action_tokens, action_token_ids, 2000 @@ -226,9 +222,7 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) -> loss.backward() assert any( - param.grad is not None - for param in policy.model.action_model.parameters() - if param.requires_grad + param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad ) assert set(batch) == set(batch_before) for key, value in batch.items(): @@ -299,9 +293,7 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None: local_files_only=True, ) except LocalEntryNotFoundError: - pytest.skip( - f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache." - ) + pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.") try: checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)