mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
iterate on review
This commit is contained in:
+1
-1
@@ -129,7 +129,7 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
xlva = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
|
||||
@@ -378,7 +378,6 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||
|
||||
# is gripper continuous? not bce?
|
||||
gripper_loss = (
|
||||
self.mse(
|
||||
pred[:, :, [5, 11]],
|
||||
|
||||
@@ -50,10 +50,11 @@ from transformers.utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig
|
||||
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig
|
||||
from .utils import drop_path
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -665,11 +666,6 @@ class DaViT(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
@@ -2435,116 +2431,6 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The FLORENCE2 vision model without any head""",
|
||||
FLORENCE2_START_DOCSTRING,
|
||||
)
|
||||
class Florence2VisionModel(Florence2PreTrainedModel):
|
||||
def __init__(self, config: Florence2VisionConfig):
|
||||
super().__init__(config)
|
||||
assert config.model_type == "davit", "only DaViT is supported for now"
|
||||
self.vision_tower = DaViT.from_config(config=config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def forward(self, pixel_values):
|
||||
if len(pixel_values.shape) == 4:
|
||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
||||
else:
|
||||
raise ValueError(f"invalid image shape {pixel_values.shape}")
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The FLORENCE2 vision model with projection layer""",
|
||||
FLORENCE2_START_DOCSTRING,
|
||||
)
|
||||
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
||||
def __init__(self, config: Florence2VisionConfig):
|
||||
super().__init__(config)
|
||||
assert config.model_type == "davit", "only DaViT is supported for now"
|
||||
self.vision_tower = DaViT.from_config(config=config)
|
||||
|
||||
self._build_image_projection_layers(config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def _build_image_projection_layers(self, config):
|
||||
image_dim_out = config.dim_embed[-1]
|
||||
dim_projection = config.projection_dim
|
||||
self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
|
||||
self.image_proj_norm = nn.LayerNorm(dim_projection)
|
||||
image_pos_embed_config = config.image_pos_embed
|
||||
if image_pos_embed_config["type"] == "learned_abs_2d":
|
||||
self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
|
||||
embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not implemented yet")
|
||||
|
||||
self.image_feature_source = config.image_feature_source
|
||||
|
||||
# temporal embedding
|
||||
visual_temporal_embedding_config = config.visual_temporal_embedding
|
||||
if visual_temporal_embedding_config["type"] == "COSINE":
|
||||
self.visual_temporal_embed = PositionalEmbeddingCosine1D(
|
||||
embed_dim=image_dim_out,
|
||||
max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not implemented yet")
|
||||
|
||||
def forward(self, pixel_values):
|
||||
if len(pixel_values.shape) == 4:
|
||||
batch_size, channels, height, width = pixel_values.shape
|
||||
num_frames = 1
|
||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
||||
else:
|
||||
raise ValueError(f"invalid image shape {pixel_values.shape}")
|
||||
|
||||
if self.image_pos_embed is not None:
|
||||
x = x.view(batch_size * num_frames, -1, x.shape[-1])
|
||||
num_tokens = x.shape[-2]
|
||||
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
|
||||
assert h * w == num_tokens, "only support square feature maps for now"
|
||||
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
|
||||
pos_embed = self.image_pos_embed(x)
|
||||
x = x + pos_embed
|
||||
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
|
||||
|
||||
if self.visual_temporal_embed is not None:
|
||||
visual_temporal_embed = self.visual_temporal_embed(
|
||||
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
|
||||
)
|
||||
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
|
||||
1, num_frames, 1, x.shape[-1]
|
||||
)
|
||||
|
||||
x_feat_dict = {}
|
||||
|
||||
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
|
||||
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
|
||||
|
||||
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
|
||||
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
|
||||
|
||||
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
|
||||
x_feat_dict["last_frame"] = x
|
||||
|
||||
new_x = []
|
||||
for _image_feature_source in self.image_feature_source:
|
||||
if _image_feature_source not in x_feat_dict:
|
||||
raise ValueError(f"invalid image feature source: {_image_feature_source}")
|
||||
new_x.append(x_feat_dict[_image_feature_source])
|
||||
|
||||
x = torch.cat(new_x, dim=1)
|
||||
|
||||
x = x @ self.image_projection
|
||||
x = self.image_proj_norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The FLORENCE2 model which consists of a vision backbone and a language model.""",
|
||||
FLORENCE2_START_DOCSTRING,
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -352,7 +353,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
feature = self.config.action_feature
|
||||
if feature is None:
|
||||
return actions
|
||||
desired_dim = feature.shape[0]
|
||||
desired_dim = self.model.dim_action
|
||||
if desired_dim == actions.shape[-1]:
|
||||
return actions
|
||||
if desired_dim < actions.shape[-1]:
|
||||
@@ -434,7 +435,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
instance = cls(config, **kwargs)
|
||||
# step 2: locate model.safetensors
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
logging.info("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, "model.safetensors")
|
||||
else:
|
||||
try:
|
||||
@@ -455,7 +456,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||
|
||||
print(f"Loading checkpoint from {model_file}")
|
||||
logging.info(f"Loading checkpoint from {model_file}")
|
||||
# step 3: load state dict
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||
@@ -465,7 +466,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
# or deepcopy
|
||||
# step 4: load into instance
|
||||
instance.load_state_dict(state_dict, strict=True)
|
||||
print("Loaded XVLA checkpoint")
|
||||
logging.info("Loaded XVLA checkpoint")
|
||||
# step 5: finalize
|
||||
# Reapply dtype after loading state dict
|
||||
instance.model._apply_dtype()
|
||||
|
||||
@@ -70,8 +70,8 @@ def make_xvla_pre_post_processors(
|
||||
),
|
||||
XVLAImageToFloatProcessorStep(),
|
||||
XVLAImageNetNormalizeProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
XVLAAddDomainIdProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
@@ -426,11 +426,9 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||
|
||||
Args:
|
||||
domain_id: The domain ID to add (default: 3)
|
||||
device: Device to place the domain_id tensor on (default: "cuda")
|
||||
"""
|
||||
|
||||
domain_id: int = 0
|
||||
device: str = "cuda"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Add domain_id to complementary data."""
|
||||
@@ -448,7 +446,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||
break
|
||||
|
||||
# Add domain_id tensor
|
||||
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long).to(self.device)
|
||||
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
|
||||
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return new_transition
|
||||
@@ -461,7 +459,6 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"domain_id": self.domain_id,
|
||||
"device": self.device,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user