mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
update testing script
This commit is contained in:
@@ -129,33 +129,44 @@ class XVLAConfig(PreTrainedConfig):
|
|||||||
config_dict["vision_config"] = {
|
config_dict["vision_config"] = {
|
||||||
"model_type": "davit",
|
"model_type": "davit",
|
||||||
"drop_path_rate": 0.1,
|
"drop_path_rate": 0.1,
|
||||||
"patch_size": [14, 7, 7, 7],
|
"patch_size": [7, 3, 3, 3],
|
||||||
"patch_stride": [4, 2, 2, 2],
|
"patch_stride": [4, 2, 2, 2],
|
||||||
"patch_padding": [3, 1, 1, 1],
|
"patch_padding": [3, 1, 1, 1],
|
||||||
"patch_prenorm": [False, True, True, True],
|
"patch_prenorm": [False, True, True, True],
|
||||||
|
"enable_checkpoint": False,
|
||||||
"dim_embed": [256, 512, 1024, 2048],
|
"dim_embed": [256, 512, 1024, 2048],
|
||||||
"num_heads": [8, 16, 32, 64],
|
"num_heads": [8, 16, 32, 64],
|
||||||
"num_groups": [8, 16, 32, 64],
|
"num_groups": [8, 16, 32, 64],
|
||||||
"depths": [1, 1, 9, 1],
|
"depths": [1, 1, 9, 1],
|
||||||
"window_size": 12,
|
"window_size": 12,
|
||||||
"projection_dim": 1024,
|
"projection_dim": 1024,
|
||||||
"visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100},
|
"visual_temporal_embedding": {
|
||||||
"image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50},
|
"type": "COSINE",
|
||||||
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"],
|
"max_temporal_embeddings": 100
|
||||||
}
|
},
|
||||||
|
"image_pos_embed": {
|
||||||
|
"type": "learned_abs_2d",
|
||||||
|
"max_pos_embeddings": 50
|
||||||
|
},
|
||||||
|
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"]
|
||||||
|
}
|
||||||
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||||
# Provide default text config
|
# Provide default text config
|
||||||
config_dict["text_config"] = {
|
config_dict["text_config"] = {
|
||||||
"model_type": "florence2_language",
|
|
||||||
"vocab_size": 51289,
|
"vocab_size": 51289,
|
||||||
|
"activation_dropout": 0.1,
|
||||||
|
"activation_function": "gelu",
|
||||||
|
"attention_dropout": 0.1,
|
||||||
"d_model": 1024,
|
"d_model": 1024,
|
||||||
"encoder_layers": 12,
|
"decoder_attention_heads": 16,
|
||||||
"decoder_layers": 12,
|
"decoder_layers": 12,
|
||||||
"encoder_attention_heads": 16,
|
"encoder_attention_heads": 16,
|
||||||
"decoder_attention_heads": 16,
|
"encoder_layers": 12,
|
||||||
"encoder_ffn_dim": 4096,
|
"dropout": 0.1,
|
||||||
"decoder_ffn_dim": 4096,
|
"max_position_embeddings": 4096,
|
||||||
}
|
"num_hidden_layers": 12,
|
||||||
|
"num_beams": 3
|
||||||
|
}
|
||||||
self._florence_config_obj = Florence2Config(**config_dict)
|
self._florence_config_obj = Florence2Config(**config_dict)
|
||||||
return self._florence_config_obj
|
return self._florence_config_obj
|
||||||
|
|
||||||
|
|||||||
+29
-2
@@ -1,10 +1,37 @@
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
from lerobot.policies.factory import make_policy, make_policy_config
|
from lerobot.policies.factory import make_policy, make_policy_config
|
||||||
|
import os
|
||||||
cfg = make_policy_config("xvla")
|
cfg = make_policy_config("xvla")
|
||||||
|
|
||||||
dataset_id = "lerobot/svla_so101_pickplace"
|
dataset_id = "lerobot/svla_so101_pickplace"
|
||||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||||
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
|
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
|
||||||
print(policy)
|
|
||||||
|
for name, param in policy.state_dict().items():
|
||||||
|
print(name, param.shape)
|
||||||
|
|
||||||
|
|
||||||
|
# now let's load in safetensors
|
||||||
|
import safetensors.torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model")
|
||||||
|
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
|
||||||
|
# policy.load_state_dict(state_dict)
|
||||||
|
# 3. Add "model." prefix to every key
|
||||||
|
new_state_dict = {f"model.{k}": v for k, v in state_dict.items()}
|
||||||
|
keys_to_skip = [
|
||||||
|
"model.transformer.action_encoder.fc.weight",
|
||||||
|
"model.transformer.action_encoder.fc.bias",
|
||||||
|
]
|
||||||
|
|
||||||
|
new_state_dict = {k: v for k, v in new_state_dict.items() if k not in keys_to_skip}
|
||||||
|
# 4. Load into your model
|
||||||
|
missing, unexpected = policy.load_state_dict(new_state_dict, strict=False)
|
||||||
|
|
||||||
|
print("missing keys:", missing)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("unexpected keys:", unexpected)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user