From 8d9a9929533d334d32efe12bb709ed947c53e27e Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 10 Nov 2025 13:17:47 +0100 Subject: [PATCH] update testing script --- .../policies/xvla/configuration_xvla.py | 33 ++++++++++++------- test_xvla.py | 31 +++++++++++++++-- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py index 2658b87b0..ef3afa465 100644 --- a/src/lerobot/policies/xvla/configuration_xvla.py +++ b/src/lerobot/policies/xvla/configuration_xvla.py @@ -129,33 +129,44 @@ class XVLAConfig(PreTrainedConfig): config_dict["vision_config"] = { "model_type": "davit", "drop_path_rate": 0.1, - "patch_size": [14, 7, 7, 7], + "patch_size": [7, 3, 3, 3], "patch_stride": [4, 2, 2, 2], "patch_padding": [3, 1, 1, 1], "patch_prenorm": [False, True, True, True], + "enable_checkpoint": False, "dim_embed": [256, 512, 1024, 2048], "num_heads": [8, 16, 32, 64], "num_groups": [8, 16, 32, 64], "depths": [1, 1, 9, 1], "window_size": 12, "projection_dim": 1024, - "visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100}, - "image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50}, - "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"], - } + "visual_temporal_embedding": { + "type": "COSINE", + "max_temporal_embeddings": 100 + }, + "image_pos_embed": { + "type": "learned_abs_2d", + "max_pos_embeddings": 50 + }, + "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"] + } if "text_config" not in config_dict or config_dict["text_config"] is None: # Provide default text config config_dict["text_config"] = { - "model_type": "florence2_language", "vocab_size": 51289, + "activation_dropout": 0.1, + "activation_function": "gelu", + "attention_dropout": 0.1, "d_model": 1024, - "encoder_layers": 12, + "decoder_attention_heads": 16, "decoder_layers": 12, "encoder_attention_heads": 16, - "decoder_attention_heads": 16, - "encoder_ffn_dim": 4096, - "decoder_ffn_dim": 4096, - } + "encoder_layers": 12, + "dropout": 0.1, + "max_position_embeddings": 4096, + "num_hidden_layers": 12, + "num_beams": 3 + } self._florence_config_obj = Florence2Config(**config_dict) return self._florence_config_obj diff --git a/test_xvla.py b/test_xvla.py index 5cf8817f3..1d82be009 100644 --- a/test_xvla.py +++ b/test_xvla.py @@ -1,10 +1,37 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.factory import make_policy, make_policy_config - +import os cfg = make_policy_config("xvla") dataset_id = "lerobot/svla_so101_pickplace" # This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets dataset_metadata = LeRobotDatasetMetadata(dataset_id) policy = make_policy(cfg=cfg, ds_meta=dataset_metadata) -print(policy) + +for name, param in policy.state_dict().items(): + print(name, param.shape) + + +# now let's load in safetensors +import safetensors.torch +from huggingface_hub import snapshot_download + +cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model") +state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors")) +# policy.load_state_dict(state_dict) +# 3. Add "model." prefix to every key +new_state_dict = {f"model.{k}": v for k, v in state_dict.items()} +keys_to_skip = [ + "model.transformer.action_encoder.fc.weight", + "model.transformer.action_encoder.fc.bias", +] + +new_state_dict = {k: v for k, v in new_state_dict.items() if k not in keys_to_skip} +# 4. Load into your model +missing, unexpected = policy.load_state_dict(new_state_dict, strict=False) + +print("missing keys:", missing) + +print() +print("unexpected keys:", unexpected) +