iterate on review

This commit is contained in:
Jade Choghari
2025-11-28 10:16:11 +01:00
parent 9cdf46bd3d
commit 4ad41f7a76
5 changed files with 10 additions and 127 deletions
+1 -1
View File
@@ -129,7 +129,7 @@ groot = [
"ninja>=1.11.1,<2.0.0", "ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
] ]
xlva = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features # Features
-1
View File
@@ -378,7 +378,6 @@ class BimanualSO101ActionSpace(BaseActionSpace):
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6]) left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12]) right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
# is gripper continuous? not bce?
gripper_loss = ( gripper_loss = (
self.mse( self.mse(
pred[:, :, [5, 11]], pred[:, :, [5, 11]],
+2 -116
View File
@@ -50,10 +50,11 @@ from transformers.utils import (
replace_return_docstrings, replace_return_docstrings,
) )
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig from .configuration_florence2 import Florence2Config, Florence2LanguageConfig
from .utils import drop_path from .utils import drop_path
if is_flash_attn_2_available(): if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -665,11 +666,6 @@ class DaViT(nn.Module):
) )
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# Copied from transformers.models.llama.modeling_llama._get_unpad_data # Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask): def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -2435,116 +2431,6 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
""" """
@add_start_docstrings(
"""The FLORENCE2 vision model without any head""",
FLORENCE2_START_DOCSTRING,
)
class Florence2VisionModel(Florence2PreTrainedModel):
def __init__(self, config: Florence2VisionConfig):
super().__init__(config)
assert config.model_type == "davit", "only DaViT is supported for now"
self.vision_tower = DaViT.from_config(config=config)
self.post_init()
def forward(self, pixel_values):
if len(pixel_values.shape) == 4:
x = self.vision_tower.forward_features_unpool(pixel_values)
else:
raise ValueError(f"invalid image shape {pixel_values.shape}")
return x
@add_start_docstrings(
"""The FLORENCE2 vision model with projection layer""",
FLORENCE2_START_DOCSTRING,
)
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
def __init__(self, config: Florence2VisionConfig):
super().__init__(config)
assert config.model_type == "davit", "only DaViT is supported for now"
self.vision_tower = DaViT.from_config(config=config)
self._build_image_projection_layers(config)
self.post_init()
def _build_image_projection_layers(self, config):
image_dim_out = config.dim_embed[-1]
dim_projection = config.projection_dim
self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
self.image_proj_norm = nn.LayerNorm(dim_projection)
image_pos_embed_config = config.image_pos_embed
if image_pos_embed_config["type"] == "learned_abs_2d":
self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"]
)
else:
raise NotImplementedError("Not implemented yet")
self.image_feature_source = config.image_feature_source
# temporal embedding
visual_temporal_embedding_config = config.visual_temporal_embedding
if visual_temporal_embedding_config["type"] == "COSINE":
self.visual_temporal_embed = PositionalEmbeddingCosine1D(
embed_dim=image_dim_out,
max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"],
)
else:
raise NotImplementedError("Not implemented yet")
def forward(self, pixel_values):
if len(pixel_values.shape) == 4:
batch_size, channels, height, width = pixel_values.shape
num_frames = 1
x = self.vision_tower.forward_features_unpool(pixel_values)
else:
raise ValueError(f"invalid image shape {pixel_values.shape}")
if self.image_pos_embed is not None:
x = x.view(batch_size * num_frames, -1, x.shape[-1])
num_tokens = x.shape[-2]
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
assert h * w == num_tokens, "only support square feature maps for now"
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
pos_embed = self.image_pos_embed(x)
x = x + pos_embed
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
if self.visual_temporal_embed is not None:
visual_temporal_embed = self.visual_temporal_embed(
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
)
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
1, num_frames, 1, x.shape[-1]
)
x_feat_dict = {}
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
x_feat_dict["last_frame"] = x
new_x = []
for _image_feature_source in self.image_feature_source:
if _image_feature_source not in x_feat_dict:
raise ValueError(f"invalid image feature source: {_image_feature_source}")
new_x.append(x_feat_dict[_image_feature_source])
x = torch.cat(new_x, dim=1)
x = x @ self.image_projection
x = self.image_proj_norm(x)
return x
@add_start_docstrings( @add_start_docstrings(
"""The FLORENCE2 model which consists of a vision backbone and a language model.""", """The FLORENCE2 model which consists of a vision backbone and a language model.""",
FLORENCE2_START_DOCSTRING, FLORENCE2_START_DOCSTRING,
+5 -4
View File
@@ -19,6 +19,7 @@
from __future__ import annotations from __future__ import annotations
import builtins import builtins
import logging
import os import os
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
@@ -352,7 +353,7 @@ class XVLAPolicy(PreTrainedPolicy):
feature = self.config.action_feature feature = self.config.action_feature
if feature is None: if feature is None:
return actions return actions
desired_dim = feature.shape[0] desired_dim = self.model.dim_action
if desired_dim == actions.shape[-1]: if desired_dim == actions.shape[-1]:
return actions return actions
if desired_dim < actions.shape[-1]: if desired_dim < actions.shape[-1]:
@@ -434,7 +435,7 @@ class XVLAPolicy(PreTrainedPolicy):
instance = cls(config, **kwargs) instance = cls(config, **kwargs)
# step 2: locate model.safetensors # step 2: locate model.safetensors
if os.path.isdir(model_id): if os.path.isdir(model_id):
print("Loading weights from local directory") logging.info("Loading weights from local directory")
model_file = os.path.join(model_id, "model.safetensors") model_file = os.path.join(model_id, "model.safetensors")
else: else:
try: try:
@@ -455,7 +456,7 @@ class XVLAPolicy(PreTrainedPolicy):
except HfHubHTTPError as e: except HfHubHTTPError as e:
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
print(f"Loading checkpoint from {model_file}") logging.info(f"Loading checkpoint from {model_file}")
# step 3: load state dict # step 3: load state dict
state_dict = safetensors.torch.load_file(model_file) state_dict = safetensors.torch.load_file(model_file)
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight" encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
@@ -465,7 +466,7 @@ class XVLAPolicy(PreTrainedPolicy):
# or deepcopy # or deepcopy
# step 4: load into instance # step 4: load into instance
instance.load_state_dict(state_dict, strict=True) instance.load_state_dict(state_dict, strict=True)
print("Loaded XVLA checkpoint") logging.info("Loaded XVLA checkpoint")
# step 5: finalize # step 5: finalize
# Reapply dtype after loading state dict # Reapply dtype after loading state dict
instance.model._apply_dtype() instance.model._apply_dtype()
+2 -5
View File
@@ -70,8 +70,8 @@ def make_xvla_pre_post_processors(
), ),
XVLAImageToFloatProcessorStep(), XVLAImageToFloatProcessorStep(),
XVLAImageNetNormalizeProcessorStep(), XVLAImageNetNormalizeProcessorStep(),
DeviceProcessorStep(device=config.device),
XVLAAddDomainIdProcessorStep(), XVLAAddDomainIdProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features=features, norm_map=config.normalization_mapping, stats=dataset_stats features=features, norm_map=config.normalization_mapping, stats=dataset_stats
), ),
@@ -426,11 +426,9 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
Args: Args:
domain_id: The domain ID to add (default: 3) domain_id: The domain ID to add (default: 3)
device: Device to place the domain_id tensor on (default: "cuda")
""" """
domain_id: int = 0 domain_id: int = 0
device: str = "cuda"
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add domain_id to complementary data.""" """Add domain_id to complementary data."""
@@ -448,7 +446,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
break break
# Add domain_id tensor # Add domain_id tensor
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long).to(self.device) comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return new_transition return new_transition
@@ -461,7 +459,6 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
"""Return serializable configuration.""" """Return serializable configuration."""
return { return {
"domain_id": self.domain_id, "domain_id": self.domain_id,
"device": self.device,
} }