From e9b3889bd2ce993f75b4b9e7ed0b69c8e030df3d Mon Sep 17 00:00:00 2001 From: nemo Date: Mon, 24 Nov 2025 15:30:26 +0100 Subject: [PATCH] Clean up loading code - Centralized instantiation of the PEFT wrapper in `make_policy` for inference (e.g. in `lerobot-record`) - Training a PEFT policy also sets `cfg.use_peft` so that all inference code loading the policy can rely on that attribute to identify if PEFT loading is needed - Modified RTC example to also include PEFT policies. Mostly because this is an example I'm currently exploring. --- examples/rtc/eval_with_real_robot.py | 9 ++++++++- src/lerobot/policies/factory.py | 15 ++++++++++++++- src/lerobot/scripts/lerobot_record.py | 20 +------------------- src/lerobot/scripts/lerobot_train.py | 6 +++++- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 6f051485a..18eb821c4 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -455,7 +455,14 @@ def demo_cli(cfg: RTCDemoConfig): if cfg.policy.type == "pi05" or cfg.policy.type == "pi0": config.compile_model = cfg.use_torch_compile - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) + if config.use_peft: + from peft import PeftModel + + policy = policy_class(config=config) + policy = PeftModel.from_pretrained(policy, cfg.policy.pretrained_path) + + else: + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) # Turn on RTC policy.config.rtc_config = cfg.rtc diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index eb6266757..605f518ea 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -406,7 +406,7 @@ def make_policy( cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} kwargs["config"] = cfg - if cfg.pretrained_path: + if cfg.pretrained_path and not cfg.use_peft: # Load a 
pretrained policy and override the config if needed (for example, if there are inference-time # hyperparameters that we want to vary). kwargs["pretrained_name_or_path"] = cfg.pretrained_path @@ -415,6 +415,19 @@ # Make a fresh policy. policy = policy_cls(**kwargs) + if cfg.pretrained_path and cfg.use_peft: + # Load a pretrained PEFT model on top of the policy. This requires that the policy was instantiated from + # scratch since PEFT is handling base model loading via the adapter config. + from peft import PeftModel + + logging.info("Loading policy's PEFT adapter.") + policy = PeftModel.from_pretrained(policy, cfg.pretrained_path) + elif cfg.use_peft: + raise ValueError( + "Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires " + "the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that." + ) + + policy.to(cfg.device) assert isinstance(policy, torch.nn.Module) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 28da73be2..d376c2c31 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -194,15 +194,9 @@ class RecordConfig: if policy_path: cli_overrides = parser.get_cli_overrides("policy") - # In case of a PEFT model We assume that the user saved the policy config (`config.json`) alongside the - # adapter parameters / config. If they didn't we could instantiate the default configuration for the policy - # but we wouldn't know if that is correct. So, in case of a missing config this will simply fail. 
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path - if (Path(policy_path) / "adapter_config.json").exists(): - self.policy.use_peft = True - if self.teleop is None and self.policy is None: raise ValueError("Choose a policy, a teleoperator or both to control the robot") @@ -433,19 +427,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: ) # Load pretrained policy - if cfg.policy and cfg.policy.use_peft: - from peft import PeftModel - - logging.info("Loading policy's PEFT adapter.") - # in case of PEFT we re-use the policy pretrained path to point to the adapter path. - peft_path = cfg.policy.pretrained_path - cfg.policy.pretrained_path = None - - policy = make_policy(cfg.policy, ds_meta=dataset.meta) - policy = PeftModel.from_pretrained(policy, peft_path) - - else: - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) preprocessor = None postprocessor = None diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index d3548b772..9c85e9dc0 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -188,7 +188,7 @@ def wrap_policy_in_peft_model(cfg, policy): if peft_config_cli["init_type"] is not None: if peft_method_type == "LORA": peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"] - elif peft_method_type == "BONE": + elif peft_method_type == "MISS": peft_config_policy["init_weights"] = peft_config_cli["init_type"] else: raise ValueError( @@ -200,6 +200,10 @@ def wrap_policy_in_peft_model(cfg, policy): peft_config_cls(**peft_config_policy), ) + # Make sure that the config is tagged as using PEFT so that the loading code can take the + # appropriate steps to use the adapter weights and the PEFT config instead of the full model weights. 
+ policy.config.use_peft = True + return policy