mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
Revamp pretrained model loading
There were quite a few factors that convinced me that the status quo is able to load pretrained models from the PEFT adapter config but in fact that didn't work. This commit fixes the following things: - policies wrapped in PEFT will now have a `name_or_path` attribute containing the name or path of the pretrained model we're fine-tuning - we further assume that SmolVLA without `pretrained_path` and `load_vlm_weights==False` must be an user-side error - we assume that using PEFT on from-scratch-policies must be an user-side-error
This commit is contained in:
@@ -406,27 +406,44 @@ def make_policy(
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if not cfg.pretrained_path and cfg.use_peft:
|
||||
raise ValueError(
|
||||
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
|
||||
"the PEFT config parameters to be set. For traning with PEFT, see `lerobot_train.py` on how to do that."
|
||||
)
|
||||
|
||||
if cfg.pretrained_path and not cfg.use_peft:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
# of the adapter and the adapter's config contains the path to the base policy. So we need the
|
||||
# adapter config first, then load the correct policy and then apply PEFT.
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
|
||||
peft_pretrained_path = cfg.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||
if not kwargs["pretrained_name_or_path"]:
|
||||
# This means that there's a bug or we trained a policy from scratch using PEFT.
|
||||
# It is more likely that this is a bug so we'll raise an error.
|
||||
raise ValueError(
|
||||
"No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which "
|
||||
"the adapter was trained."
|
||||
)
|
||||
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
if cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. This requires that the policy was instantiated from
|
||||
# scratch since PEFT is handling base model loading via the adapter config.
|
||||
from peft import PeftModel
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
policy = PeftModel.from_pretrained(policy, cfg.pretrained_path)
|
||||
elif cfg.use_peft:
|
||||
raise ValueError(
|
||||
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
|
||||
"the PEFT config parameters to be set. For traning with PEFT, see `lerobot_train.py` on how to do that."
|
||||
)
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
|
||||
@@ -173,7 +173,7 @@ def wrap_policy_in_peft_model(cfg, policy):
|
||||
"Consider supplying a `policy.path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
if cfg.policy.type == "smolvla" and not cfg.load_vlm_weights:
|
||||
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
|
||||
logging.warning(
|
||||
"Traning SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
|
||||
"`load_vlm_weights=True` to fine-tune the existing policy."
|
||||
@@ -207,6 +207,12 @@ def wrap_policy_in_peft_model(cfg, policy):
|
||||
f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
|
||||
)
|
||||
|
||||
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
|
||||
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
|
||||
# adapter, not the base model.
|
||||
policy.name_or_path = str(policy.config.pretrained_path)
|
||||
|
||||
# Finally wrap the policy in a PEFT model
|
||||
policy = get_peft_model(
|
||||
policy,
|
||||
peft_config_cls(**peft_config_policy),
|
||||
|
||||
Reference in New Issue
Block a user