Add custom XVLA from_pretrained checkpoint loading, image resizing in the model forward pass, and Libero training launch scripts

This commit is contained in:
jade.choghari@huggingface.co
2025-11-10 14:53:17 +00:00
parent 8d9a992953
commit 2219c29690
7 changed files with 130 additions and 7 deletions
+2
View File
@@ -428,6 +428,8 @@ def make_policy(
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)
kwargs["pretrained_name_or_path"] = "/fsx/jade_choghari/.cache/huggingface/model/xvla-libero"
policy = policy_cls.from_pretrained(**kwargs)
policy.to(cfg.device)
assert isinstance(policy, torch.nn.Module)
@@ -70,7 +70,7 @@ class XVLAConfig(PreTrainedConfig):
num_domains: int = 30
len_soft_prompts: int = 32
dim_time: int = 32
max_len_seq: int = 512
max_len_seq: int = 512  # TODO(jadechoghari): consider increasing from 512 to 1024
use_hetero_proj: bool = False
# Action & proprioception
+99 -1
View File
@@ -34,6 +34,7 @@ from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .transformer import SoftPromptedTransformer
import os
class XVLAModel(nn.Module):
"""
@@ -94,6 +95,11 @@ class XVLAModel(nn.Module):
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
flat_images = pixel_values.flatten(0, 1)
#TODO: jadechoghari: remove this resizing logic, and provide a way in training to do this
target_size = (224, 224)
flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
num_valid = int(flat_mask.sum().item())
if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.")
@@ -105,7 +111,6 @@ class XVLAModel(nn.Module):
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
image_features[flat_mask] = valid_feats
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
image_features[:, 0],
@@ -337,6 +342,99 @@ class XVLAPolicy(PreTrainedPolicy):
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
    cls,
    pretrained_name_or_path: str | Path,
    *,
    config: "PreTrainedConfig" | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = False,
    **kwargs,
):
    """Load an XVLA policy from a local directory or the Hugging Face Hub.

    The checkpoint stores weights without the ``model.`` prefix used by this
    class, so every key is re-prefixed before loading. A small set of action
    encoder/decoder projection layers is deliberately dropped from the
    checkpoint so they remain randomly initialized.

    Args:
        pretrained_name_or_path: Local directory or Hub repo id containing
            ``model.safetensors`` (and a config, when ``config`` is ``None``).
        config: Optional pre-built policy config; loaded from
            ``pretrained_name_or_path`` when ``None``.
        force_download, resume_download, proxies, token, cache_dir,
        local_files_only, revision: Forwarded to the Hub download machinery.
        strict: Forwarded to ``load_state_dict``; when ``True``, missing or
            unexpected keys raise instead of being reported.
        **kwargs: Extra arguments forwarded to config loading and to the
            policy constructor.

    Returns:
        The instantiated policy, moved to ``config.device`` and set to eval
        mode.

    Raises:
        FileNotFoundError: If ``model.safetensors`` cannot be located locally
            or on the Hub.
    """
    import safetensors.torch

    # --- Step 1: resolve the config ---
    if config is None:
        config = PreTrainedConfig.from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )

    model_id = str(pretrained_name_or_path)
    instance = cls(config, **kwargs)

    # --- Step 2: locate model.safetensors ---
    if os.path.isdir(model_id):
        print("Loading weights from local directory")
        model_file = os.path.join(model_id, "model.safetensors")
        if not os.path.isfile(model_file):
            # Fail early with a clear message instead of an opaque load error.
            raise FileNotFoundError(
                f"model.safetensors not found in local directory {model_id}"
            )
    else:
        try:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="model.safetensors",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        except HfHubHTTPError as e:
            raise FileNotFoundError(
                f"model.safetensors not found on the Hub at {model_id}"
            ) from e

    # --- Step 3: load the raw safetensors state dict ---
    print(f"Loading checkpoint from {model_file}")
    state_dict = safetensors.torch.load_file(model_file)

    # --- Step 4: re-key the checkpoint ---
    # Checkpoint keys lack the "model." prefix used by this policy's submodule.
    new_state_dict = {f"model.{k}": v for k, v in state_dict.items()}

    # Projection layers intentionally left at their random initialization.
    # Bug fix: the last entry previously read
    # "model.transformer.action_decoder.bias.weight", which breaks the
    # fc.weight/fc.bias pattern of the other entries and matches no parameter,
    # so the decoder bias was silently loaded. Use the fc bias to mirror the
    # encoder pair.
    keys_to_skip = {
        "model.transformer.action_encoder.fc.weight",
        "model.transformer.action_encoder.fc.bias",
        "model.transformer.action_decoder.fc.weight",
        "model.transformer.action_decoder.fc.bias",
    }
    new_state_dict = {k: v for k, v in new_state_dict.items() if k not in keys_to_skip}

    # --- Step 5: load into the freshly constructed policy ---
    # Bug fix: honor the `strict` argument instead of hard-coding strict=False.
    missing, unexpected = instance.load_state_dict(new_state_dict, strict=strict)
    print("✅ Loaded XVLA checkpoint with modified keys.")
    if missing:
        print(f"Missing keys: {missing}")
    if unexpected:
        print(f"Unexpected keys: {unexpected}")

    # --- Step 6: finalize ---
    instance.to(config.device)
    instance.eval()
    return instance
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
+2
View File
@@ -51,6 +51,8 @@ from lerobot.utils.utils import (
init_logging,
)
# NOTE(review): presumably a Hugging Face login step belongs here (e.g. huggingface_hub.login) — confirm before relying on private repos
def update_policy(
train_metrics: MetricsTracker,
+1 -1
View File
@@ -16,7 +16,7 @@ for name, param in policy.state_dict().items():
import safetensors.torch
from huggingface_hub import snapshot_download
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model")
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model")
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
# policy.load_state_dict(state_dict)
# 3. Add "model." prefix to every key
+7 -4
View File
@@ -1,10 +1,13 @@
lerobot-train \
--dataset.repo_id=lerobot/svla_so101_pickplace \
--dataset.repo_id=libero_dataset \
--dataset.root=/fsx/jade_choghari/datasets/libero/ \
--policy.type=xvla \
--output_dir=outputs/train/act_your_dataset \
--job_name=xvla_so101_pickplace \
--output_dir=/fsx/jade_choghari/outputs/train/xvla_libero \
--job_name=xvla_libero \
--policy.device=cuda \
--policy.action_mode=franka_joint7 \
--wandb.enable=true \
--policy.repo_id=jadechoghari/xvla_policy \
--policy.repo_id=jadechoghari/X-VLA-Libero \
--steps=10000
# # --policy.pretrained_path=/fsx/jade_choghari/.cache/huggingface/model/xvla-libero \
+18
View File
@@ -0,0 +1,18 @@
# Multi-GPU XVLA training launch on the Libero dataset.
# Runs lerobot-train under `accelerate` with 4 processes and fp16 mixed
# precision; dataset root and output dir point at local /fsx paths.
# NOTE(review): no inline comments inside the command — the backslash
# continuations would break.
accelerate launch \
--multi_gpu \
--num_processes=4 \
--mixed_precision=fp16 \
$(which lerobot-train) \
--batch_size=32 \
--save_freq=5000 \
--num_workers=32 \
--dataset.repo_id=libero_dataset \
--dataset.root=/fsx/jade_choghari/datasets/libero/ \
--policy.type=xvla \
--output_dir=/fsx/jade_choghari/outputs/train/xvla_libero_multi \
--job_name=xvla_libero \
--policy.device=cuda \
--policy.action_mode=franka_joint7 \
--wandb.enable=true \
--policy.repo_id=jadechoghari/X-VLA-Libero \
--steps=10000