mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
add changes
This commit is contained in:
@@ -428,6 +428,8 @@ def make_policy(
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
kwargs["pretrained_name_or_path"] = "/fsx/jade_choghari/.cache/huggingface/model/xvla-libero"
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
|
||||
@@ -70,7 +70,7 @@ class XVLAConfig(PreTrainedConfig):
|
||||
num_domains: int = 30
|
||||
len_soft_prompts: int = 32
|
||||
dim_time: int = 32
|
||||
max_len_seq: int = 512
|
||||
max_len_seq: int = 512 #TODO: jadechoghari: change to 512 1024
|
||||
use_hetero_proj: bool = False
|
||||
|
||||
# Action & proprioception
|
||||
|
||||
@@ -34,6 +34,7 @@ from .configuration_xvla import XVLAConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .transformer import SoftPromptedTransformer
|
||||
|
||||
import os
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
"""
|
||||
@@ -94,6 +95,11 @@ class XVLAModel(nn.Module):
|
||||
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
|
||||
flat_images = pixel_values.flatten(0, 1)
|
||||
|
||||
#TODO: jadechoghari: remove this resizing logic, and provide a way in training to do this
|
||||
target_size = (224, 224)
|
||||
flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
|
||||
|
||||
|
||||
num_valid = int(flat_mask.sum().item())
|
||||
if num_valid == 0:
|
||||
raise ValueError("At least one image view must be valid per batch.")
|
||||
@@ -105,7 +111,6 @@ class XVLAModel(nn.Module):
|
||||
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
|
||||
image_features[flat_mask] = valid_feats
|
||||
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
|
||||
|
||||
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
|
||||
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
|
||||
image_features[:, 0],
|
||||
@@ -337,6 +342,99 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
def from_pretrained(
    cls,
    pretrained_name_or_path: str | Path,
    *,
    config: "PreTrainedConfig | None" = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = False,
    **kwargs,
):
    """
    Build a policy and load XVLA weights from a ``model.safetensors`` checkpoint.

    The checkpoint is read from ``pretrained_name_or_path`` when that is a local
    directory, otherwise it is downloaded from the Hub. Before loading:
    - the prefix 'model.' is added to every checkpoint key
    - keys on the skip list are dropped, so those layers keep whatever values
      ``cls(config, **kwargs)`` initialized them with

    Args:
        pretrained_name_or_path: Local directory containing ``model.safetensors``,
            or a Hub repo id.
        config: Policy config. When ``None`` it is loaded with
            ``PreTrainedConfig.from_pretrained`` from the same location.
        force_download, resume_download, proxies, token, cache_dir,
        local_files_only, revision: Forwarded to the config loader and to
            ``hf_hub_download``.
        strict: Accepted for interface compatibility only. Loading is always
            non-strict because the skipped layers are intentionally absent from
            the filtered state dict.
        **kwargs: Forwarded to the config loader (when ``config`` is ``None``)
            and to the ``cls`` constructor.

    Returns:
        The instantiated policy, moved to ``config.device`` and set to eval mode.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is missing from the local
            directory, or the Hub download fails with ``HfHubHTTPError``.
    """
    import safetensors.torch

    # --- Step 1: Load config ---
    if config is None:
        config = PreTrainedConfig.from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )

    model_id = str(pretrained_name_or_path)
    instance = cls(config, **kwargs)

    # --- Step 2: Locate model.safetensors ---
    if os.path.isdir(model_id):
        print("Loading weights from local directory")
        model_file = os.path.join(model_id, "model.safetensors")
        # Fail early with the same exception type as the Hub branch instead of
        # surfacing a low-level error from the safetensors reader.
        if not os.path.isfile(model_file):
            raise FileNotFoundError(
                f"model.safetensors not found in local directory {model_id}"
            )
    else:
        try:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="model.safetensors",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        except HfHubHTTPError as e:
            raise FileNotFoundError(
                f"model.safetensors not found on the Hub at {model_id}"
            ) from e

    # --- Step 3: Load safetensor weights ---
    print(f"Loading checkpoint from {model_file}")
    state_dict = safetensors.torch.load_file(model_file)

    # --- Step 4: Modify keys ---
    # Layers to skip (left at their freshly constructed values).
    # NOTE(review): names are kept exactly as written; confirm them against the
    # checkpoint, since a name that matches no key is silently a no-op.
    keys_to_skip = frozenset(
        {
            "model.transformer.action_encoder.fc.weight",
            "model.transformer.action_encoder.fc.bias",
            "model.transformer.action_decoder.fc.weight",
            "model.transformer.action_decoder.bias.weight",
        }
    )
    # Prefix and filter in a single pass over the checkpoint.
    new_state_dict = {}
    for key, tensor in state_dict.items():
        prefixed_key = f"model.{key}"
        if prefixed_key not in keys_to_skip:
            new_state_dict[prefixed_key] = tensor

    # --- Step 5: Load into instance ---
    # Always non-strict: the skipped keys are expected to be reported as missing.
    missing, unexpected = instance.load_state_dict(new_state_dict, strict=False)
    print("✅ Loaded XVLA checkpoint with modified keys.")
    if missing:
        print(f"Missing keys: {missing}")
    if unexpected:
        print(f"Unexpected keys: {unexpected}")

    # --- Step 6: Finalize ---
    instance.to(config.device)
    instance.eval()
    return instance
|
||||
|
||||
|
||||
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
|
||||
|
||||
@@ -51,6 +51,8 @@ from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
)
|
||||
|
||||
# login to hf
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ for name, param in policy.state_dict().items():
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model")
|
||||
cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model")
|
||||
state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors"))
|
||||
# policy.load_state_dict(state_dict)
|
||||
# 3. Add "model." prefix to every key
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
lerobot-train \
|
||||
--dataset.repo_id=lerobot/svla_so101_pickplace \
|
||||
--dataset.repo_id=libero_dataset \
|
||||
--dataset.root=/fsx/jade_choghari/datasets/libero/ \
|
||||
--policy.type=xvla \
|
||||
--output_dir=outputs/train/act_your_dataset \
|
||||
--job_name=xvla_so101_pickplace \
|
||||
--output_dir=/fsx/jade_choghari/outputs/train/xvla_libero \
|
||||
--job_name=xvla_libero \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=franka_joint7 \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=jadechoghari/xvla_policy \
|
||||
--policy.repo_id=jadechoghari/X-VLA-Libero \
|
||||
--steps=10000
|
||||
|
||||
# # --policy.pretrained_path=/fsx/jade_choghari/.cache/huggingface/model/xvla-libero \
|
||||
@@ -0,0 +1,18 @@
|
||||
# Launch XVLA training on the LIBERO dataset with accelerate:
# multi-GPU, 4 processes, fp16 mixed precision.

# Resolve the training entry point up front so a missing install fails with a
# clear message instead of accelerate misparsing the remaining flags.
train_bin="$(command -v lerobot-train)" || {
  echo "lerobot-train not found on PATH" >&2
  exit 1
}

accelerate launch \
  --multi_gpu \
  --num_processes=4 \
  --mixed_precision=fp16 \
  "$train_bin" \
  --batch_size=32 \
  --save_freq=5000 \
  --num_workers=32 \
  --dataset.repo_id=libero_dataset \
  --dataset.root=/fsx/jade_choghari/datasets/libero/ \
  --policy.type=xvla \
  --output_dir=/fsx/jade_choghari/outputs/train/xvla_libero_multi \
  --job_name=xvla_libero \
  --policy.device=cuda \
  --policy.action_mode=franka_joint7 \
  --wandb.enable=true \
  --policy.repo_id=jadechoghari/X-VLA-Libero \
  --steps=10000
|
||||
Reference in New Issue
Block a user