mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Clean up loading code
- Centralized instantiation of the PEFT wrapper in `make_policy` for inference (e.g. in `lerobot-record`) - Training a PEFT policy also sets `cfg.use_peft` so that all inference code loading the policy can rely on that attribute to identify if PEFT loading is needed - Modified RTC example to also include PEFT policies. Mostly because this is an example I'm currently exploring.
This commit is contained in:
@@ -455,6 +455,13 @@ def demo_cli(cfg: RTCDemoConfig):
|
|||||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||||
config.compile_model = cfg.use_torch_compile
|
config.compile_model = cfg.use_torch_compile
|
||||||
|
|
||||||
|
if config.use_peft:
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
policy = policy_class(config=config)
|
||||||
|
policy = PeftModel.from_pretrained(policy, cfg.policy.pretrained_path)
|
||||||
|
|
||||||
|
else:
|
||||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||||
|
|
||||||
# Turn on RTC
|
# Turn on RTC
|
||||||
|
|||||||
@@ -406,7 +406,7 @@ def make_policy(
|
|||||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||||
kwargs["config"] = cfg
|
kwargs["config"] = cfg
|
||||||
|
|
||||||
if cfg.pretrained_path:
|
if cfg.pretrained_path and not cfg.use_peft:
|
||||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||||
# hyperparameters that we want to vary).
|
# hyperparameters that we want to vary).
|
||||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||||
@@ -415,6 +415,19 @@ def make_policy(
|
|||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
|
|
||||||
|
if cfg.pretrained_path and cfg.use_peft:
|
||||||
|
# Load a pretrained PEFT model on top of the policy. This requires that the policy was instantiated from
|
||||||
|
# scratch since PEFT is handling base model loading via the adapter config.
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
logging.info("Loading policy's PEFT adapter.")
|
||||||
|
policy = PeftModel.from_pretrained(policy, cfg.pretrained_path)
|
||||||
|
elif cfg.use_peft:
|
||||||
|
raise ValueError(
|
||||||
|
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
|
||||||
|
"the PEFT config parameters to be set. For traning with PEFT, see `lerobot_train.py` on how to do that."
|
||||||
|
)
|
||||||
|
|
||||||
policy.to(cfg.device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, torch.nn.Module)
|
assert isinstance(policy, torch.nn.Module)
|
||||||
|
|
||||||
|
|||||||
@@ -194,15 +194,9 @@ class RecordConfig:
|
|||||||
if policy_path:
|
if policy_path:
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
|
|
||||||
# In case of a PEFT model We assume that the user saved the policy config (`config.json`) alongside the
|
|
||||||
# adapter parameters / config. If they didn't we could instantiate the default configuration for the policy
|
|
||||||
# but we wouldn't know if that is correct. So, in case of a missing config this will simply fail.
|
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
if (Path(policy_path) / "adapter_config.json").exists():
|
|
||||||
self.policy.use_peft = True
|
|
||||||
|
|
||||||
if self.teleop is None and self.policy is None:
|
if self.teleop is None and self.policy is None:
|
||||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||||
|
|
||||||
@@ -433,18 +427,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
if cfg.policy and cfg.policy.use_peft:
|
|
||||||
from peft import PeftModel
|
|
||||||
|
|
||||||
logging.info("Loading policy's PEFT adapter.")
|
|
||||||
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
|
|
||||||
peft_path = cfg.policy.pretrained_path
|
|
||||||
cfg.policy.pretrained_path = None
|
|
||||||
|
|
||||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
|
||||||
policy = PeftModel.from_pretrained(policy, peft_path)
|
|
||||||
|
|
||||||
else:
|
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||||
|
|
||||||
preprocessor = None
|
preprocessor = None
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ def wrap_policy_in_peft_model(cfg, policy):
|
|||||||
if peft_config_cli["init_type"] is not None:
|
if peft_config_cli["init_type"] is not None:
|
||||||
if peft_method_type == "LORA":
|
if peft_method_type == "LORA":
|
||||||
peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
|
peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
|
||||||
elif peft_method_type == "BONE":
|
elif peft_method_type == "MISS":
|
||||||
peft_config_policy["init_weights"] = peft_config_cli["init_type"]
|
peft_config_policy["init_weights"] = peft_config_cli["init_type"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -200,6 +200,10 @@ def wrap_policy_in_peft_model(cfg, policy):
|
|||||||
peft_config_cls(**peft_config_policy),
|
peft_config_cls(**peft_config_policy),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Make sure that the config is tagged as using PEFT so that the loading code can take the
|
||||||
|
# appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
|
||||||
|
policy.config.use_peft = True
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user