mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
iterate on review
This commit is contained in:
+1
-1
@@ -129,7 +129,7 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
xlva = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
|
|||||||
@@ -378,7 +378,6 @@ class BimanualSO101ActionSpace(BaseActionSpace):
|
|||||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||||
|
|
||||||
# is gripper continuous? not bce?
|
|
||||||
gripper_loss = (
|
gripper_loss = (
|
||||||
self.mse(
|
self.mse(
|
||||||
pred[:, :, [5, 11]],
|
pred[:, :, [5, 11]],
|
||||||
|
|||||||
@@ -50,10 +50,11 @@ from transformers.utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig
|
from .configuration_florence2 import Florence2Config, Florence2LanguageConfig
|
||||||
from .utils import drop_path
|
from .utils import drop_path
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -665,11 +666,6 @@ class DaViT(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
|
||||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
||||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||||
def _get_unpad_data(attention_mask):
|
def _get_unpad_data(attention_mask):
|
||||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||||
@@ -2435,116 +2431,6 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"""The FLORENCE2 vision model without any head""",
|
|
||||||
FLORENCE2_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class Florence2VisionModel(Florence2PreTrainedModel):
|
|
||||||
def __init__(self, config: Florence2VisionConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
assert config.model_type == "davit", "only DaViT is supported for now"
|
|
||||||
self.vision_tower = DaViT.from_config(config=config)
|
|
||||||
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
|
||||||
if len(pixel_values.shape) == 4:
|
|
||||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"invalid image shape {pixel_values.shape}")
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"""The FLORENCE2 vision model with projection layer""",
|
|
||||||
FLORENCE2_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
||||||
def __init__(self, config: Florence2VisionConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
assert config.model_type == "davit", "only DaViT is supported for now"
|
|
||||||
self.vision_tower = DaViT.from_config(config=config)
|
|
||||||
|
|
||||||
self._build_image_projection_layers(config)
|
|
||||||
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def _build_image_projection_layers(self, config):
|
|
||||||
image_dim_out = config.dim_embed[-1]
|
|
||||||
dim_projection = config.projection_dim
|
|
||||||
self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
|
|
||||||
self.image_proj_norm = nn.LayerNorm(dim_projection)
|
|
||||||
image_pos_embed_config = config.image_pos_embed
|
|
||||||
if image_pos_embed_config["type"] == "learned_abs_2d":
|
|
||||||
self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
|
|
||||||
embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Not implemented yet")
|
|
||||||
|
|
||||||
self.image_feature_source = config.image_feature_source
|
|
||||||
|
|
||||||
# temporal embedding
|
|
||||||
visual_temporal_embedding_config = config.visual_temporal_embedding
|
|
||||||
if visual_temporal_embedding_config["type"] == "COSINE":
|
|
||||||
self.visual_temporal_embed = PositionalEmbeddingCosine1D(
|
|
||||||
embed_dim=image_dim_out,
|
|
||||||
max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Not implemented yet")
|
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
|
||||||
if len(pixel_values.shape) == 4:
|
|
||||||
batch_size, channels, height, width = pixel_values.shape
|
|
||||||
num_frames = 1
|
|
||||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"invalid image shape {pixel_values.shape}")
|
|
||||||
|
|
||||||
if self.image_pos_embed is not None:
|
|
||||||
x = x.view(batch_size * num_frames, -1, x.shape[-1])
|
|
||||||
num_tokens = x.shape[-2]
|
|
||||||
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
|
|
||||||
assert h * w == num_tokens, "only support square feature maps for now"
|
|
||||||
x = x.view(batch_size * num_frames, h, w, x.shape[-1])
|
|
||||||
pos_embed = self.image_pos_embed(x)
|
|
||||||
x = x + pos_embed
|
|
||||||
x = x.view(batch_size, num_frames * h * w, x.shape[-1])
|
|
||||||
|
|
||||||
if self.visual_temporal_embed is not None:
|
|
||||||
visual_temporal_embed = self.visual_temporal_embed(
|
|
||||||
x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
|
|
||||||
)
|
|
||||||
x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
|
|
||||||
1, num_frames, 1, x.shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
x_feat_dict = {}
|
|
||||||
|
|
||||||
spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
|
|
||||||
x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
|
|
||||||
|
|
||||||
temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
|
|
||||||
x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
|
|
||||||
|
|
||||||
x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
|
|
||||||
x_feat_dict["last_frame"] = x
|
|
||||||
|
|
||||||
new_x = []
|
|
||||||
for _image_feature_source in self.image_feature_source:
|
|
||||||
if _image_feature_source not in x_feat_dict:
|
|
||||||
raise ValueError(f"invalid image feature source: {_image_feature_source}")
|
|
||||||
new_x.append(x_feat_dict[_image_feature_source])
|
|
||||||
|
|
||||||
x = torch.cat(new_x, dim=1)
|
|
||||||
|
|
||||||
x = x @ self.image_projection
|
|
||||||
x = self.image_proj_norm(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""The FLORENCE2 model which consists of a vision backbone and a language model.""",
|
"""The FLORENCE2 model which consists of a vision backbone and a language model.""",
|
||||||
FLORENCE2_START_DOCSTRING,
|
FLORENCE2_START_DOCSTRING,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -352,7 +353,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
feature = self.config.action_feature
|
feature = self.config.action_feature
|
||||||
if feature is None:
|
if feature is None:
|
||||||
return actions
|
return actions
|
||||||
desired_dim = feature.shape[0]
|
desired_dim = self.model.dim_action
|
||||||
if desired_dim == actions.shape[-1]:
|
if desired_dim == actions.shape[-1]:
|
||||||
return actions
|
return actions
|
||||||
if desired_dim < actions.shape[-1]:
|
if desired_dim < actions.shape[-1]:
|
||||||
@@ -434,7 +435,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
instance = cls(config, **kwargs)
|
instance = cls(config, **kwargs)
|
||||||
# step 2: locate model.safetensors
|
# step 2: locate model.safetensors
|
||||||
if os.path.isdir(model_id):
|
if os.path.isdir(model_id):
|
||||||
print("Loading weights from local directory")
|
logging.info("Loading weights from local directory")
|
||||||
model_file = os.path.join(model_id, "model.safetensors")
|
model_file = os.path.join(model_id, "model.safetensors")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -455,7 +456,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
except HfHubHTTPError as e:
|
except HfHubHTTPError as e:
|
||||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||||
|
|
||||||
print(f"Loading checkpoint from {model_file}")
|
logging.info(f"Loading checkpoint from {model_file}")
|
||||||
# step 3: load state dict
|
# step 3: load state dict
|
||||||
state_dict = safetensors.torch.load_file(model_file)
|
state_dict = safetensors.torch.load_file(model_file)
|
||||||
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||||
@@ -465,7 +466,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
# or deepcopy
|
# or deepcopy
|
||||||
# step 4: load into instance
|
# step 4: load into instance
|
||||||
instance.load_state_dict(state_dict, strict=True)
|
instance.load_state_dict(state_dict, strict=True)
|
||||||
print("Loaded XVLA checkpoint")
|
logging.info("Loaded XVLA checkpoint")
|
||||||
# step 5: finalize
|
# step 5: finalize
|
||||||
# Reapply dtype after loading state dict
|
# Reapply dtype after loading state dict
|
||||||
instance.model._apply_dtype()
|
instance.model._apply_dtype()
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ def make_xvla_pre_post_processors(
|
|||||||
),
|
),
|
||||||
XVLAImageToFloatProcessorStep(),
|
XVLAImageToFloatProcessorStep(),
|
||||||
XVLAImageNetNormalizeProcessorStep(),
|
XVLAImageNetNormalizeProcessorStep(),
|
||||||
DeviceProcessorStep(device=config.device),
|
|
||||||
XVLAAddDomainIdProcessorStep(),
|
XVLAAddDomainIdProcessorStep(),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
NormalizerProcessorStep(
|
NormalizerProcessorStep(
|
||||||
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||||
),
|
),
|
||||||
@@ -426,11 +426,9 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
domain_id: The domain ID to add (default: 3)
|
domain_id: The domain ID to add (default: 3)
|
||||||
device: Device to place the domain_id tensor on (default: "cuda")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
domain_id: int = 0
|
domain_id: int = 0
|
||||||
device: str = "cuda"
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""Add domain_id to complementary data."""
|
"""Add domain_id to complementary data."""
|
||||||
@@ -448,7 +446,7 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Add domain_id tensor
|
# Add domain_id tensor
|
||||||
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long).to(self.device)
|
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
|
||||||
|
|
||||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||||
return new_transition
|
return new_transition
|
||||||
@@ -461,7 +459,6 @@ class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
|||||||
"""Return serializable configuration."""
|
"""Return serializable configuration."""
|
||||||
return {
|
return {
|
||||||
"domain_id": self.domain_id,
|
"domain_id": self.domain_id,
|
||||||
"device": self.device,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user