mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
pre-commit cleanup
This commit is contained in:
@@ -34,7 +34,6 @@ def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
@@ -215,26 +214,28 @@ class VLAJEPAActionHead(nn.Module):
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict([
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
])
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
])
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(
|
||||
config.num_embodied_action_tokens_per_instruction, inner_dim
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
|
||||
@@ -31,6 +31,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file as save_safetensors
|
||||
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -54,7 +55,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_source_key(key: str) -> str:
|
||||
return key[len("module."):] if key.startswith("module.") else key
|
||||
return key[len("module.") :] if key.startswith("module.") else key
|
||||
|
||||
|
||||
def _map_checkpoint_key(raw_key: str) -> str | None:
|
||||
@@ -62,11 +63,11 @@ def _map_checkpoint_key(raw_key: str) -> str | None:
|
||||
key = _normalize_source_key(raw_key)
|
||||
|
||||
if key.startswith("qwen_vl_interface."):
|
||||
return "model.qwen." + key[len("qwen_vl_interface."):]
|
||||
return "model.qwen." + key[len("qwen_vl_interface.") :]
|
||||
if key.startswith("vj_encoder."):
|
||||
return "model.video_encoder." + key[len("vj_encoder."):]
|
||||
return "model.video_encoder." + key[len("vj_encoder.") :]
|
||||
if key.startswith("vj_predictor."):
|
||||
return "model.video_predictor." + key[len("vj_predictor."):]
|
||||
return "model.video_predictor." + key[len("vj_predictor.") :]
|
||||
if key.startswith("action_model."):
|
||||
# LeRobot code uses the same sub-key names as the source checkpoint,
|
||||
# so only the top-level "model." prefix needs to be added.
|
||||
@@ -94,7 +95,6 @@ def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict
|
||||
return None
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Architecture — identical across all 4 variants (from config.json)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -295,7 +295,7 @@ def main() -> None:
|
||||
config._save_pretrained(save_dir) # writes config.json via draccus
|
||||
|
||||
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
|
||||
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
|
||||
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
|
||||
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
|
||||
|
||||
log.info(" Uploading …")
|
||||
|
||||
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file as load_safetensors_file
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
@@ -159,12 +159,20 @@ class ACRoPEAttention(nn.Module):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device))
|
||||
kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device))
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
action_k.append(torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
|
||||
@@ -403,9 +403,9 @@ def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None
|
||||
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
|
||||
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
@@ -466,6 +466,7 @@ def test_postprocessor_applied_after_predict_action_chunk(
|
||||
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
|
||||
unnormed = postprocessor(chunk)
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
|
||||
|
||||
@@ -43,7 +43,9 @@ def _make_predictor(
|
||||
],
|
||||
)
|
||||
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
|
||||
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame)
|
||||
predictor = _make_predictor(
|
||||
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
|
||||
)
|
||||
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
|
||||
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
|
||||
out = predictor(frame_tokens, action_tokens)
|
||||
|
||||
Reference in New Issue
Block a user