From 0edbb68ec3baf976b379f4dcac840498a83a1135 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Thu, 21 May 2026 13:16:57 +0200 Subject: [PATCH] pre-commit cleanup --- src/lerobot/policies/vla_jepa/action_head.py | 29 ++++++++++--------- .../vla_jepa/convert_vla_jepa_checkpoints.py | 12 ++++---- .../policies/vla_jepa/modeling_vla_jepa.py | 1 - src/lerobot/policies/vla_jepa/world_model.py | 16 +++++++--- tests/policies/vla_jepa/test_vla_jepa.py | 3 +- tests/policies/vla_jepa/test_world_model.py | 4 ++- 6 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index 200ecdd91..0f17e7845 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -34,7 +34,6 @@ def swish(x: torch.Tensor) -> torch.Tensor: return x * torch.sigmoid(x) - class SinusoidalPositionalEncoding(nn.Module): def __init__(self, embedding_dim: int): super().__init__() @@ -215,26 +214,28 @@ class VLAJEPAActionHead(nn.Module): ) self.action_encoder = ActionEncoder(config.action_dim, inner_dim) self.action_decoder = nn.Sequential( - OrderedDict([ - ("layer1", nn.Linear(hidden_size, hidden_size)), - ("relu", nn.ReLU()), - ("layer2", nn.Linear(hidden_size, config.action_dim)), - ]) + OrderedDict( + [ + ("layer1", nn.Linear(hidden_size, hidden_size)), + ("relu", nn.ReLU()), + ("layer2", nn.Linear(hidden_size, config.action_dim)), + ] + ) ) self.state_encoder = ( nn.Sequential( - OrderedDict([ - ("layer1", nn.Linear(config.state_dim, hidden_size)), - ("relu", nn.ReLU()), - ("layer2", nn.Linear(hidden_size, inner_dim)), - ]) + OrderedDict( + [ + ("layer1", nn.Linear(config.state_dim, hidden_size)), + ("relu", nn.ReLU()), + ("layer2", nn.Linear(hidden_size, inner_dim)), + ] + ) ) if config.state_dim > 0 else None ) - self.future_tokens = nn.Embedding( - config.num_embodied_action_tokens_per_instruction, inner_dim - ) + self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim) self.position_embedding = nn.Embedding( max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4), inner_dim, diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index 753291bb5..90120c9bf 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -31,6 +31,7 @@ from pathlib import Path import torch from huggingface_hub import HfApi from safetensors.torch import save_file as save_safetensors + from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors # --------------------------------------------------------------------------- @@ -54,7 +55,7 @@ log = logging.getLogger(__name__) def _normalize_source_key(key: str) -> str: - return key[len("module."):] if key.startswith("module.") else key + return key[len("module.") :] if key.startswith("module.") else key def _map_checkpoint_key(raw_key: str) -> str | None: @@ -62,11 +63,11 @@ def _map_checkpoint_key(raw_key: str) -> str | None: key = _normalize_source_key(raw_key) if key.startswith("qwen_vl_interface."): - return "model.qwen." + key[len("qwen_vl_interface."):] + return "model.qwen." + key[len("qwen_vl_interface.") :] if key.startswith("vj_encoder."): - return "model.video_encoder." + key[len("vj_encoder."):] + return "model.video_encoder." + key[len("vj_encoder.") :] if key.startswith("vj_predictor."): - return "model.video_predictor." + key[len("vj_predictor."):] + return "model.video_predictor." + key[len("vj_predictor.") :] if key.startswith("action_model."): # LeRobot code uses the same sub-key names as the source checkpoint, # so only the top-level "model." prefix needs to be added. @@ -94,7 +95,6 @@ def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict return None - # --------------------------------------------------------------------------- # Architecture — identical across all 4 variants (from config.json) # --------------------------------------------------------------------------- @@ -295,7 +295,7 @@ def main() -> None: config._save_pretrained(save_dir) # writes config.json via draccus preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats) - preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json + preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json log.info(" Uploading …") diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 2f0ebbb4b..7d728b774 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -9,7 +9,6 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 from PIL import Image -from safetensors.torch import load_file as load_safetensors_file from torch import Tensor, nn from lerobot.policies.pretrained import PreTrainedPolicy, T diff --git a/src/lerobot/policies/vla_jepa/world_model.py b/src/lerobot/policies/vla_jepa/world_model.py index 1df495e82..9359f188a 100644 --- a/src/lerobot/policies/vla_jepa/world_model.py +++ b/src/lerobot/policies/vla_jepa/world_model.py @@ -159,12 +159,20 @@ class ACRoPEAttention(nn.Module): action_token = x[:, :, idx : idx + 1, :].flatten(1, 2) qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)) - kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)) + qd = rotate_queries_or_keys( + q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device) + ) + kd = rotate_queries_or_keys( + k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device) + ) qr = q[..., self.d_dim :] kr = k[..., self.d_dim :] - action_q.append(torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)) - action_k.append(torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)) + action_q.append( + torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1) + ) + action_k.append( + torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1) + ) action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1)) action_q = torch.cat(action_q, dim=3).flatten(2, 3) diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index 37ea46da5..52b00697c 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -403,9 +403,9 @@ def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None: """ClipActionsProcessorStep clamps to [-1, 1] before unnormalization.""" from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep from lerobot.processor import UnnormalizerProcessorStep from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action - from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep from lerobot.utils.constants import ACTION dataset_stats = _make_dataset_stats() @@ -466,6 +466,7 @@ def test_postprocessor_applied_after_predict_action_chunk( # Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i unnormed = postprocessor(chunk) from lerobot.utils.constants import ACTION + a_min = dataset_stats[ACTION]["min"].numpy() a_max = dataset_stats[ACTION]["max"].numpy() expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0] diff --git a/tests/policies/vla_jepa/test_world_model.py b/tests/policies/vla_jepa/test_world_model.py index 0c341b993..555b2cd11 100644 --- a/tests/policies/vla_jepa/test_world_model.py +++ b/tests/policies/vla_jepa/test_world_model.py @@ -43,7 +43,9 @@ def _make_predictor( ], ) def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None: - predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame) + predictor = _make_predictor( + embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame + ) frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim) action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM) out = predictor(frame_tokens, action_tokens)