diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 5be3bca43..754e0c857 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -514,7 +514,7 @@ def make_policy( logging.info("Loading policy's PEFT adapter.") - peft_pretrained_path = cfg.pretrained_path + peft_pretrained_path = str(cfg.pretrained_path) peft_config = PeftConfig.from_pretrained(peft_pretrained_path) kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path @@ -527,7 +527,9 @@ def make_policy( ) policy = policy_cls.from_pretrained(**kwargs) - policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) + policy = PeftModel.from_pretrained( + policy, peft_pretrained_path, config=peft_config, is_trainable=True + ) else: # Make a fresh policy. diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9d7330e6a..55a8cc935 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -277,9 +277,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if cfg.peft is not None: if cfg.is_reward_model_training: raise ValueError("PEFT is only supported for policy training. ") - logging.info("Using PEFT! Wrapping model.") - peft_cli_overrides = dataclasses.asdict(cfg.peft) - policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) + from peft import PeftModel + + if isinstance(policy, PeftModel): + logging.info("PEFT adapter already loaded from checkpoint, skipping wrap_with_peft.") + else: + logging.info("Using PEFT! Wrapping model.") + peft_cli_overrides = dataclasses.asdict(cfg.peft) + policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) # Wait for all processes to finish model creation before continuing accelerator.wait_for_everyone()