Revamp pretrained model loading

There were quite a few factors that convinced me that the status quo
is able to load pretrained models from the PEFT adapter config but
in fact that didn't work.

This commit fixes the following things:
- policies wrapped in PEFT will now have a `name_or_path` attribute
  containing the name or path of the pretrained model we're fine-tuning
- we further assume that SmolVLA without `pretrained_path` and
  `load_vlm_weights==False` must be an user-side error
- we assume that using PEFT on from-scratch-policies must be
  an user-side-error
This commit is contained in:
nemo
2025-12-01 18:30:45 +01:00
parent e0b6aca97a
commit 75266860aa
2 changed files with 36 additions and 13 deletions
+29 -12
View File
@@ -406,27 +406,44 @@ def make_policy(
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if not cfg.pretrained_path and cfg.use_peft:
raise ValueError(
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
"the PEFT config parameters to be set. For traning with PEFT, see `lerobot_train.py` on how to do that."
)
if cfg.pretrained_path and not cfg.use_peft:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs)
elif cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
# of the adapter and the adapter's config contains the path to the base policy. So we need the
# adapter config first, then load the correct policy and then apply PEFT.
from peft import PeftConfig, PeftModel
logging.info("Loading policy's PEFT adapter.")
peft_pretrained_path = cfg.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
if not kwargs["pretrained_name_or_path"]:
# This means that there's a bug or we trained a policy from scratch using PEFT.
# It is more likely that this is a bug so we'll raise an error.
raise ValueError(
"No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which "
"the adapter was trained."
)
policy = policy_cls.from_pretrained(**kwargs)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)
if cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. This requires that the policy was instantiated from
# scratch since PEFT is handling base model loading via the adapter config.
from peft import PeftModel
logging.info("Loading policy's PEFT adapter.")
policy = PeftModel.from_pretrained(policy, cfg.pretrained_path)
elif cfg.use_peft:
raise ValueError(
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
"the PEFT config parameters to be set. For traning with PEFT, see `lerobot_train.py` on how to do that."
)
policy.to(cfg.device)
assert isinstance(policy, torch.nn.Module)
+7 -1
View File
@@ -173,7 +173,7 @@ def wrap_policy_in_peft_model(cfg, policy):
"Consider supplying a `policy.path` to fine-tune an existing model."
)
if cfg.policy.type == "smolvla" and not cfg.load_vlm_weights:
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
logging.warning(
"Traning SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
"`load_vlm_weights=True` to fine-tune the existing policy."
@@ -207,6 +207,12 @@ def wrap_policy_in_peft_model(cfg, policy):
f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
)
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
# adapter, not the base model.
policy.name_or_path = str(policy.config.pretrained_path)
# Finally wrap the policy in a PEFT model
policy = get_peft_model(
policy,
peft_config_cls(**peft_config_policy),