diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 0f8afafe4..0aec74811 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -428,6 +428,8 @@ def make_policy( else: # Make a fresh policy. policy = policy_cls(**kwargs) + kwargs["pretrained_name_or_path"] = "/fsx/jade_choghari/.cache/huggingface/model/xvla-libero" + policy = policy_cls.from_pretrained(**kwargs) policy.to(cfg.device) assert isinstance(policy, torch.nn.Module) diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py index ef3afa465..ed3badc75 100644 --- a/src/lerobot/policies/xvla/configuration_xvla.py +++ b/src/lerobot/policies/xvla/configuration_xvla.py @@ -70,7 +70,7 @@ class XVLAConfig(PreTrainedConfig): num_domains: int = 30 len_soft_prompts: int = 32 dim_time: int = 32 - max_len_seq: int = 512 + max_len_seq: int = 512 #TODO: jadechoghari: change to 512 1024 use_hetero_proj: bool = False # Action & proprioception diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index b973e713c..1e3db3517 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -34,6 +34,7 @@ from .configuration_xvla import XVLAConfig from .modeling_florence2 import Florence2ForConditionalGeneration from .transformer import SoftPromptedTransformer +import os class XVLAModel(nn.Module): """ @@ -94,6 +95,11 @@ class XVLAModel(nn.Module): flat_mask = image_mask.view(-1).to(dtype=torch.bool) flat_images = pixel_values.flatten(0, 1) + #TODO: jadechoghari: remove this resizing logic, and provide a way in training to do this + target_size = (224, 224) + flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False) + + num_valid = int(flat_mask.sum().item()) if num_valid == 0: raise ValueError("At least one image view must be valid per batch.") @@ -105,7 +111,6 @@ class 
XVLAModel(nn.Module): image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim)) image_features[flat_mask] = valid_feats image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim) - inputs_embeds = self.vlm.get_input_embeddings()(input_ids) merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features( image_features[:, 0], @@ -337,6 +342,99 @@ class XVLAPolicy(PreTrainedPolicy): self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) return self._queues[ACTION].popleft() + + @classmethod + def from_pretrained( + cls, + pretrained_name_or_path: str | Path, + *, + config: "PreTrainedConfig" | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = False, + **kwargs, + ): + """ + Loads XVLA model weights with: + - automatic prefix 'model.' 
added to all keys + - skip list for layers that should remain randomly initialized + """ + import safetensors.torch + # --- Step 1: Load config --- + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + model_id = str(pretrained_name_or_path) + instance = cls(config, **kwargs) + + # --- Step 2: Locate model.safetensors --- + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, "model.safetensors") + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename="model.safetensors", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + raise FileNotFoundError( + f"model.safetensors not found on the Hub at {model_id}" + ) from e + + # --- Step 3: Load safetensor weights --- + print(f"Loading checkpoint from {model_file}") + state_dict = safetensors.torch.load_file(model_file) + + # --- Step 4: Modify keys --- + new_state_dict = {f"model.{k}": v for k, v in state_dict.items()} + + # Layers to skip (reinitialize) + keys_to_skip = [ + "model.transformer.action_encoder.fc.weight", + "model.transformer.action_encoder.fc.bias", + "model.transformer.action_decoder.fc.weight", + "model.transformer.action_decoder.fc.bias" + ] + new_state_dict = { + k: v for k, v in new_state_dict.items() + if k not in keys_to_skip + } + + # --- Step 5: Load into instance --- + missing, unexpected = instance.load_state_dict(new_state_dict, strict=False) + print("✅ Loaded XVLA checkpoint with modified keys.") + if missing: + print(f"Missing keys: {missing}") + if unexpected: + print(f"Unexpected 
keys: {unexpected}") + + # --- Step 6: Finalize --- + instance.to(config.device) + instance.eval() + return instance def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 0cc6e037f..607fc7bc9 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -51,6 +51,8 @@ from lerobot.utils.utils import ( init_logging, ) +# login to hf + def update_policy( train_metrics: MetricsTracker, diff --git a/test_xvla.py b/test_xvla.py index 1d82be009..6b8ab182d 100644 --- a/test_xvla.py +++ b/test_xvla.py @@ -16,7 +16,7 @@ for name, param in policy.state_dict().items(): import safetensors.torch from huggingface_hub import snapshot_download -cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model") +cache_dir = snapshot_download(repo_id="2toINF/X-VLA-Libero", repo_type="model", cache_dir="/fsx/jade_choghari/.cache/huggingface/model") state_dict = safetensors.torch.load_file(os.path.join(cache_dir, "model.safetensors")) # policy.load_state_dict(state_dict) # 3. Add "model." 
prefix to every key diff --git a/train.sh b/train.sh index 4683936ae..5b73de847 100644 --- a/train.sh +++ b/train.sh @@ -1,10 +1,13 @@ lerobot-train \ - --dataset.repo_id=lerobot/svla_so101_pickplace \ + --dataset.repo_id=libero_dataset \ + --dataset.root=/fsx/jade_choghari/datasets/libero/ \ --policy.type=xvla \ - --output_dir=outputs/train/act_your_dataset \ - --job_name=xvla_so101_pickplace \ + --output_dir=/fsx/jade_choghari/outputs/train/xvla_libero \ + --job_name=xvla_libero \ --policy.device=cuda \ --policy.action_mode=franka_joint7 \ --wandb.enable=true \ - --policy.repo_id=jadechoghari/xvla_policy \ + --policy.repo_id=jadechoghari/X-VLA-Libero \ --steps=10000 + +# # --policy.pretrained_path=/fsx/jade_choghari/.cache/huggingface/model/xvla-libero \ \ No newline at end of file diff --git a/train_multi.sh b/train_multi.sh new file mode 100644 index 000000000..fa3668d75 --- /dev/null +++ b/train_multi.sh @@ -0,0 +1,18 @@ +accelerate launch \ + --multi_gpu \ + --num_processes=4 \ + --mixed_precision=fp16 \ + $(which lerobot-train) \ + --batch_size=32 \ + --save_freq=5000 \ + --num_workers=32 \ + --dataset.repo_id=libero_dataset \ + --dataset.root=/fsx/jade_choghari/datasets/libero/ \ + --policy.type=xvla \ + --output_dir=/fsx/jade_choghari/outputs/train/xvla_libero_multi \ + --job_name=xvla_libero \ + --policy.device=cuda \ + --policy.action_mode=franka_joint7 \ + --wandb.enable=true \ + --policy.repo_id=jadechoghari/X-VLA-Libero \ + --steps=10000