mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
linting
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F # noqa: N812
|
||||||
from diffusers import ConfigMixin, ModelMixin
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
from diffusers.configuration_utils import register_to_config
|
from diffusers.configuration_utils import register_to_config
|
||||||
from diffusers.models.attention import Attention, FeedForward
|
from diffusers.models.attention import Attention, FeedForward
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F # noqa: N812
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformers import AutoModel, AutoVideoProcessor
|
from transformers import AutoModel, AutoVideoProcessor
|
||||||
@@ -139,18 +139,18 @@ class VLAJEPAModel(nn.Module):
|
|||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
|
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
|
||||||
B, _, H = last_hidden.shape
|
b, _, h = last_hidden.shape
|
||||||
|
|
||||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(B, -1, H)
|
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||||
|
|
||||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
|
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||||
|
|
||||||
# ---- Step 2: JEPA Encoder (same as original) ----
|
# ---- Step 2: JEPA Encoder (same as original) ----
|
||||||
B, V, T_frames, C, H_img, W_img = batch_videos.shape
|
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||||
batch_videos_flat = batch_videos.reshape(B * V, T_frames, C, H_img, W_img)
|
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||||
|
|
||||||
video_pixels = []
|
video_pixels = []
|
||||||
for i in range(B * V):
|
for i in range(b * v):
|
||||||
video_pixels.append(
|
video_pixels.append(
|
||||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||||
"pixel_values_videos"
|
"pixel_values_videos"
|
||||||
@@ -161,41 +161,41 @@ class VLAJEPAModel(nn.Module):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=V, dim=0), dim=2)
|
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||||
|
|
||||||
# ---- Step 3: JEPA Predictor (same as original) ----
|
# ---- Step 3: JEPA Predictor (same as original) ----
|
||||||
tubelet_size = self.video_encoder.config.tubelet_size
|
tubelet_size = self.video_encoder.config.tubelet_size
|
||||||
T_enc = T_frames // tubelet_size
|
t_enc = t_frames // tubelet_size
|
||||||
device_wm = video_embeddings.device
|
device_wm = video_embeddings.device
|
||||||
|
|
||||||
if T_enc < 2:
|
if t_enc < 2:
|
||||||
# Not enough frames for JEPA prediction (need at least 2 encoded frames)
|
# Not enough frames for JEPA prediction (need at least 2 encoded frames)
|
||||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||||
else:
|
else:
|
||||||
tokens_per_frame = video_embeddings.shape[1] // T_enc
|
tokens_per_frame = video_embeddings.shape[1] // t_enc
|
||||||
|
|
||||||
# input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
|
# input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
|
||||||
# gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
|
# gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
|
||||||
input_states = video_embeddings[:, : tokens_per_frame * (T_enc - 1), :]
|
input_states = video_embeddings[:, : tokens_per_frame * (t_enc - 1), :]
|
||||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||||
D_emb = input_states.shape[-1]
|
d_emb = input_states.shape[-1]
|
||||||
|
|
||||||
# Reshape to 4D for ActionConditionedVideoPredictor:
|
# Reshape to 4D for ActionConditionedVideoPredictor:
|
||||||
# [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
|
# [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
|
||||||
input_states_4d = input_states.view(B, T_enc - 1, tokens_per_frame, D_emb)
|
input_states_4d = input_states.view(b, t_enc - 1, tokens_per_frame, d_emb)
|
||||||
|
|
||||||
# Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
|
# Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
|
||||||
expected_actions = (T_enc - 1) * self.config.num_action_tokens_per_timestep
|
expected_actions = (t_enc - 1) * self.config.num_action_tokens_per_timestep
|
||||||
if action_tokens.shape[1] < expected_actions:
|
if action_tokens.shape[1] < expected_actions:
|
||||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||||
act_4d = action_tokens[:, :expected_actions].view(
|
act_4d = action_tokens[:, :expected_actions].view(
|
||||||
B, T_enc - 1, self.config.num_action_tokens_per_timestep, -1
|
b, t_enc - 1, self.config.num_action_tokens_per_timestep, -1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cast to float32 for predictor (Linear layers are float32)
|
# Cast to float32 for predictor (Linear layers are float32)
|
||||||
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
|
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
|
||||||
predicted_states = pred_4d.reshape(B, -1, D_emb)
|
predicted_states = pred_4d.reshape(b, -1, d_emb)
|
||||||
|
|
||||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||||
|
|
||||||
@@ -261,8 +261,8 @@ class VLAJEPAModel(nn.Module):
|
|||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
last_hidden = qwen_outputs.hidden_states[-1]
|
last_hidden = qwen_outputs.hidden_states[-1]
|
||||||
B, _, H = last_hidden.shape
|
b, _, h = last_hidden.shape
|
||||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(B, -1, H)
|
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||||
|
|
||||||
state_tensor = None
|
state_tensor = None
|
||||||
if state is not None:
|
if state is not None:
|
||||||
@@ -349,8 +349,6 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# ---- Collect videos per sample ----
|
# ---- Collect videos per sample ----
|
||||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||||
num_views = len(image_keys)
|
|
||||||
has_video = any(batch[k].ndim == 5 for k in image_keys if k in batch)
|
|
||||||
# Check whether any image feature has a time dimension
|
# Check whether any image feature has a time dimension
|
||||||
video_source = None
|
video_source = None
|
||||||
for k in image_keys:
|
for k in image_keys:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
|||||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.filterwarnings(
|
pytestmark = pytest.mark.filterwarnings(
|
||||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||||
)
|
)
|
||||||
@@ -81,10 +80,7 @@ class _FakeQwenInterface(nn.Module):
|
|||||||
|
|
||||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||||
action_tokens = [
|
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
||||||
self.config.special_action_token.format(idx)
|
|
||||||
for idx in range(max_action_tokens)
|
|
||||||
]
|
|
||||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||||
return action_tokens, action_token_ids, 2000
|
return action_tokens, action_token_ids, 2000
|
||||||
|
|
||||||
@@ -226,9 +222,7 @@ def test_vla_jepa_training_forward_pass(patch_vla_jepa_external_models: None) ->
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
assert any(
|
assert any(
|
||||||
param.grad is not None
|
param.grad is not None for param in policy.model.action_model.parameters() if param.requires_grad
|
||||||
for param in policy.model.action_model.parameters()
|
|
||||||
if param.requires_grad
|
|
||||||
)
|
)
|
||||||
assert set(batch) == set(batch_before)
|
assert set(batch) == set(batch_before)
|
||||||
for key, value in batch.items():
|
for key, value in batch.items():
|
||||||
@@ -299,9 +293,7 @@ def test_vla_jepa_pretrained_checkpoint_loads_from_hf_cache() -> None:
|
|||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
)
|
)
|
||||||
except LocalEntryNotFoundError:
|
except LocalEntryNotFoundError:
|
||||||
pytest.skip(
|
pytest.skip(f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache.")
|
||||||
f"{repo_id}/{checkpoint_filename} is not available in the local Hugging Face cache."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True, weights_only=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user