diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index 0445d6c00..58b5dc07b 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy):
             loss = losses.mean()
             loss_dict["loss"] = loss.item()
             return loss, loss_dict
+
+    def _get_default_peft_targets(self) -> dict[str, object]:
+        """Return default PEFT target modules for PI0 fine-tuning."""
+        common_projections = (
+            "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
+        )
+        target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
+        return {
+            "target_modules": target_modules,
+            "modules_to_save": [],
+        }
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index 11d8b4d68..104ec63bf 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy):
             loss = losses.mean()
             loss_dict["loss"] = loss.item()
             return loss, loss_dict
+
+    def _get_default_peft_targets(self) -> dict[str, object]:
+        """Return default PEFT target modules for PI0.5 fine-tuning."""
+        common_projections = (
+            "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
+        )
+        target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
+        return {
+            "target_modules": target_modules,
+            "modules_to_save": [],
+        }
diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py
index a1499d077..e730b78a7 100644
--- a/src/lerobot/policies/pretrained.py
+++ b/src/lerobot/policies/pretrained.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import abc
 import builtins
+import dataclasses
 import logging
 import os
 from importlib.resources import files
@@ -265,3 +266,188 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         card = ModelCard.from_template(card_data, template_str=template_card)
         card.validate()
         return card
+
+    def wrap_with_peft(
+        self,
+        peft_config=None,
+        peft_cli_overrides: dict | None = None,
+    ) -> "PreTrainedPolicy":
+        """
+        Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
+
+        This method is the single entry point for PEFT integration. Subclasses should
+        override `_get_default_peft_targets()` to provide default target modules, and
+        `_validate_peft_config()` for policy-specific validation.
+
+        Args:
+            peft_config: Optional PEFT adapter configuration (e.g., LoraConfig).
+                If provided, used directly (with CLI overrides applied).
+            peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.)
+                These are merged with policy defaults to build the final config.
+
+        Returns:
+            The policy wrapped by `peft.get_peft_model`, with `config.use_peft` set to True.
+
+        Raises:
+            ValueError: If no target modules can be determined or validation fails.
+        """
+        from peft import get_peft_model
+
+        # If user provided a complete config, use it directly (with overrides)
+        if peft_config is not None:
+            final_config = peft_config
+            if peft_cli_overrides:
+                final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides)
+        else:
+            # Build config from defaults + CLI overrides
+            final_config = self._build_peft_config(peft_cli_overrides or {})
+
+        # Validate the configuration
+        self._validate_peft_config(final_config)
+
+        # Freeze base parameters, only adapter params will be trained
+        for p in self.parameters():
+            p.requires_grad_(False)
+
+        # Store pretrained path for PEFT's base_model_name_or_path
+        if self.config.pretrained_path:
+            self.name_or_path = str(self.config.pretrained_path)
+
+        # Wrap with PEFT
+        peft_model = get_peft_model(self, final_config)
+
+        # Mark config as using PEFT for proper loading later
+        peft_model.config.use_peft = True
+
+        logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})")
+        return peft_model
+
+    def _get_default_peft_targets(self) -> dict[str, object] | None:
+        """
+        Return default PEFT target modules for this policy.
+
+        Override this in subclasses to provide policy-specific defaults. These defaults
+        are PEFT-method agnostic - they only specify which modules to target.
+
+        """
+        return None
+
+    def _validate_peft_config(self, peft_config) -> None:
+        """
+        Validate the PEFT configuration for this policy.
+
+        Override this in subclasses to add policy-specific validation or warnings.
+        The default implementation checks that a pretrained_path exists.
+
+        Args:
+            peft_config: The PEFT configuration to validate.
+
+        Raises:
+            ValueError: If the configuration is invalid.
+        """
+        if not self.config.pretrained_path:
+            raise ValueError(
+                "Training from scratch using PEFT is unlikely to yield good results. "
+                "Supply a `policy.pretrained_path` to fine-tune an existing model."
+            )
+
+    def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict:
+        """
+        Preprocess CLI overrides: rename keys and handle method-specific init_type.
+
+        Args:
+            cli_overrides: Dict of CLI options (will be copied, not mutated).
+            peft_method_type: The PeftType enum value for the PEFT method.
+
+        Returns:
+            Preprocessed dict with renamed keys and init_type mapped to method-specific key.
+        """
+        from peft import PeftType
+
+        cli_overrides = cli_overrides.copy()
+
+        # Handle the full_training_modules -> modules_to_save rename
+        if "full_training_modules" in cli_overrides:
+            cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules")
+
+        # Remove method_type as it's handled separately
+        cli_overrides.pop("method_type", None)
+
+        # Handle init_type specially based on PEFT method
+        init_type = cli_overrides.pop("init_type", None)
+        if init_type is not None:
+            if peft_method_type == PeftType.LORA:
+                cli_overrides["init_lora_weights"] = init_type
+            elif peft_method_type == PeftType.MISS:
+                cli_overrides["init_weights"] = init_type
+            else:
+                raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.")
+
+        return cli_overrides
+
+    def _build_peft_config(self, cli_overrides: dict):
+        """Build a PEFT config from policy defaults and CLI overrides."""
+        from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
+
+        # Determine PEFT method type (default to LORA)
+        method_type_str = cli_overrides.get("method_type") or "lora"
+        peft_method_type = PeftType[method_type_str.upper()]
+        peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
+
+        # Preprocess CLI overrides
+        cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
+
+        # Start with policy defaults, apply CLI overrides
+        config_dict = dict(self._get_default_peft_targets() or {})
+        for key, value in cli_overrides.items():
+            if value is not None:
+                config_dict[key] = value
+
+        # Ensure we have target_modules
+        if not config_dict.get("target_modules"):
+            raise ValueError(
+                f"Policy '{self.name}' does not define default target_modules. "
+                "Please pass --peft.target_modules explicitly."
+            )
+
+        return peft_config_cls(**config_dict)
+
+    def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict):
+        """
+        Apply CLI overrides to an existing PEFT config.
+
+        Args:
+            peft_config: The existing PEFT configuration to start from.
+            cli_overrides: Dict of CLI options; entries whose value is None are ignored.
+
+        Returns:
+            A new PEFT config instance with the overrides applied.
+        """
+        from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
+
+        # Get method type from existing config or CLI override
+        method_type_str = cli_overrides.get("method_type")
+        if method_type_str:
+            peft_method_type = PeftType[method_type_str.upper()]
+            peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
+        else:
+            peft_method_type = PeftType(peft_config.peft_type)
+            peft_config_cls = type(peft_config)
+
+        # Preprocess CLI overrides
+        cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
+
+        # Start with the existing config's init fields, copied shallowly: `dataclasses.asdict`
+        # would recurse into nested dataclass fields and turn them into plain dicts. Keep only
+        # fields the target class accepts so that switching the method type stays constructible.
+        accepted = {f.name for f in dataclasses.fields(peft_config_cls) if f.init}
+        config_dict = {
+            f.name: getattr(peft_config, f.name)
+            for f in dataclasses.fields(peft_config)
+            if f.name in accepted and not f.name.startswith("_")
+        }
+        for key, value in cli_overrides.items():
+            if value is not None:
+                config_dict[key] = value
+
+        return peft_config_cls(**config_dict)
diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py
index f998661f9..c611e9ba2 100644
--- a/src/lerobot/policies/smolvla/modeling_smolvla.py
+++ b/src/lerobot/policies/smolvla/modeling_smolvla.py
@@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
         actions = pad_vector(batch[ACTION], self.config.max_action_dim)
         return actions
 
+    def _get_default_peft_targets(self) -> dict[str, object]:
+        """Return default PEFT target modules for SmolVLA fine-tuning."""
+        common_projections = (
+            "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
+        )
+        target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"
+        return {
+            "target_modules": target_modules,
+            "modules_to_save": [],
+        }
+
+    def _validate_peft_config(self, peft_config) -> None:
+        """Validate PEFT configuration for SmolVLA."""
+        super()._validate_peft_config(peft_config)
+        if not self.config.load_vlm_weights:
+            import logging
+
+            logging.warning(
+                "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. "
+                "Set `load_vlm_weights=True` to fine-tune the existing policy."
+            )
+
 
 def pad_tensor(tensor, max_len, pad_value=0):
     """
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 55737d5d8..41f866ce8 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -148,92 +148,6 @@ def update_policy(
     return train_metrics, output_dict
 
 
-def get_default_peft_configuration(policy_type):
-    """Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint."""
-
-    common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
-
-    if policy_type == "smolvla":
-        return {
-            "target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))",
-            "modules_to_save": [],
-        }
-    elif policy_type in ("pi0", "pi05"):
-        return {
-            "target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))",
-            "modules_to_save": [],
-        }
-
-    return {"modules_to_save": None}
-
-
-def wrap_policy_in_peft_model(cfg, policy):
-    from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
-
-    # Disable all gradients because we'll only train the parameters selected by the PEFT method.
-    # Layers that should receive gradients anyway need to be listed in `modules_to_save`.
-    for p in policy.parameters():
-        p.requires_grad_(False)
-
-    if not cfg.policy.pretrained_path:
-        raise ValueError(
-            "Training from scratch using PEFT. This is unlikely to yield good results. "
-            "Supply a `policy.path` to fine-tune an existing model."
-        )
-
-    if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
-        logging.warning(
-            "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
-            "`load_vlm_weights=True` to fine-tune the existing policy."
-        )
-
-    peft_config_policy = get_default_peft_configuration(cfg.policy.type)
-    peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
-    peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"]  # compatibility with PEFT
-    peft_method_type = PeftType[peft_config_cli["method_type"].upper()]
-    peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
-
-    # Handle specific CLI overrides
-    for key in ["target_modules", "modules_to_save", "r"]:
-        if peft_config_cli[key] is not None:
-            peft_config_policy[key] = peft_config_cli[key]
-
-    if "target_modules" not in peft_config_policy:
-        raise ValueError(
-            f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
-        )
-
-    # Init method depends on the used PEFT method, your specific PEFT method
-    # might not be considered here, in that case an error is raised.
-    if peft_config_cli["init_type"] is not None:
-        if peft_method_type == "LORA":
-            peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
-        elif peft_method_type == "MISS":
-            peft_config_policy["init_weights"] = peft_config_cli["init_type"]
-        else:
-            raise ValueError(
-                f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
-            )
-
-    # PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
-    # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
-    # adapter, not the base model.
-    if policy.config.pretrained_path:
-        policy.name_or_path = str(policy.config.pretrained_path)
-
-    # Finally wrap the policy in a PEFT model
-    policy = get_peft_model(
-        policy,
-        peft_config_cls(**peft_config_policy),
-    )
-
-    # Make sure that the config is tagged as using PEFT so that the loading code can take the
-    # appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
-    policy.config.use_peft = True
-
-    return policy
-
-
 @parser.wrap()
 def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     """
@@ -326,7 +240,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
 
     if cfg.peft is not None:
         logging.info("Using PEFT! Wrapping model.")
-        policy = wrap_policy_in_peft_model(cfg, policy)
+        # Convert CLI peft config to dict for overrides
+        peft_cli_overrides = dataclasses.asdict(cfg.peft)
+        policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
 
     # Wait for all processes to finish policy creation before continuing
     accelerator.wait_for_everyone()