pre-commit cleanup

This commit is contained in:
Maximellerbach
2026-05-21 13:16:57 +02:00
parent 01ce5d7af1
commit 7da594fda8
6 changed files with 38 additions and 27 deletions
+15 -14
View File
@@ -34,7 +34,6 @@ def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
@@ -215,26 +214,28 @@ class VLAJEPAActionHead(nn.Module):
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict([
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
])
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict([
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
])
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(
config.num_embodied_action_tokens_per_instruction, inner_dim
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
@@ -31,6 +31,7 @@ from pathlib import Path
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file as save_safetensors
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
# ---------------------------------------------------------------------------
@@ -54,7 +55,7 @@ log = logging.getLogger(__name__)
def _normalize_source_key(key: str) -> str:
return key[len("module."):] if key.startswith("module.") else key
return key[len("module.") :] if key.startswith("module.") else key
def _map_checkpoint_key(raw_key: str) -> str | None:
@@ -62,11 +63,11 @@ def _map_checkpoint_key(raw_key: str) -> str | None:
key = _normalize_source_key(raw_key)
if key.startswith("qwen_vl_interface."):
return "model.qwen." + key[len("qwen_vl_interface."):]
return "model.qwen." + key[len("qwen_vl_interface.") :]
if key.startswith("vj_encoder."):
return "model.video_encoder." + key[len("vj_encoder."):]
return "model.video_encoder." + key[len("vj_encoder.") :]
if key.startswith("vj_predictor."):
return "model.video_predictor." + key[len("vj_predictor."):]
return "model.video_predictor." + key[len("vj_predictor.") :]
if key.startswith("action_model."):
# LeRobot code uses the same sub-key names as the source checkpoint,
# so only the top-level "model." prefix needs to be added.
@@ -94,7 +95,6 @@ def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict
return None
# ---------------------------------------------------------------------------
# Architecture — identical across all 4 variants (from config.json)
# ---------------------------------------------------------------------------
@@ -295,7 +295,7 @@ def main() -> None:
config._save_pretrained(save_dir) # writes config.json via draccus
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
log.info(" Uploading …")
@@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from safetensors.torch import load_file as load_safetensors_file
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
+12 -4
View File
@@ -159,12 +159,20 @@ class ACRoPEAttention(nn.Module):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device))
kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device))
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
action_k.append(torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
+2 -1
View File
@@ -403,9 +403,9 @@ def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
@@ -466,6 +466,7 @@ def test_postprocessor_applied_after_predict_action_chunk(
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
unnormed = postprocessor(chunk)
from lerobot.utils.constants import ACTION
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
+3 -1
View File
@@ -43,7 +43,9 @@ def _make_predictor(
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame)
predictor = _make_predictor(
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
)
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)