diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 605f518ea..e5bc45698 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -406,27 +406,44 @@ def make_policy( cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg + if not cfg.pretrained_path and cfg.use_peft: + raise ValueError( + "Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires " + "the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that." + ) + if cfg.pretrained_path and not cfg.use_peft: # Load a pretrained policy and override the config if needed (for example, if there are inference-time # hyperparameters that we want to vary). kwargs["pretrained_name_or_path"] = cfg.pretrained_path policy = policy_cls.from_pretrained(**kwargs) + elif cfg.pretrained_path and cfg.use_peft: + # Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo + # of the adapter and the adapter's config contains the path to the base policy. So we need the + # adapter config first, then load the correct policy and then apply PEFT. + from peft import PeftConfig, PeftModel + + logging.info("Loading policy's PEFT adapter.") + + peft_pretrained_path = cfg.pretrained_path + peft_config = PeftConfig.from_pretrained(peft_pretrained_path) + + kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path + if not kwargs["pretrained_name_or_path"]: + # This means that there's a bug or we trained a policy from scratch using PEFT. + # It is more likely that this is a bug so we'll raise an error. + raise ValueError( + "No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which " + "the adapter was trained." 
+ ) + + policy = policy_cls.from_pretrained(**kwargs) + policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) + else: # Make a fresh policy. policy = policy_cls(**kwargs) - if cfg.pretrained_path and cfg.use_peft: - # Load a pretrained PEFT model on top of the policy. This requires that the policy was instantiated from - # scratch since PEFT is handling base model loading via the adapter config. - from peft import PeftModel - - logging.info("Loading policy's PEFT adapter.") - policy = PeftModel.from_pretrained(policy, cfg.pretrained_path) - elif cfg.use_peft: - raise ValueError( - "Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires " - "the PEFT config parameters to be set. For traning with PEFT, see `lerobot_train.py` on how to do that." - ) policy.to(cfg.device) assert isinstance(policy, torch.nn.Module) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 82523af58..cd472614f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -173,7 +173,7 @@ def wrap_policy_in_peft_model(cfg, policy): "Consider supplying a `policy.path` to fine-tune an existing model." ) - if cfg.policy.type == "smolvla" and not cfg.load_vlm_weights: + if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights: logging.warning( "Traning SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set " "`load_vlm_weights=True` to fine-tune the existing policy." ) @@ -207,6 +207,12 @@ def wrap_policy_in_peft_model(cfg, policy): f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}." ) + # PEFT uses this attribute to set adapter_config.base_model_name_or_path which we use for loading the + # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the + # adapter, not the base model. 
+ policy.name_or_path = str(policy.config.pretrained_path) + + # Finally wrap the policy in a PEFT model policy = get_peft_model( policy, peft_config_cls(**peft_config_policy),