From dc67b2ff3fb4c65456dd8362a726247134423107 Mon Sep 17 00:00:00 2001 From: nemo Date: Sun, 22 Jun 2025 19:46:10 +0200 Subject: [PATCH] Store policy config alongside PEFT checkpoint Before this change, the PEFT-wrapped policy did not save the policy's config alongside the adapter config / weights, which prevented us from changing the policy config. Now the policy config is saved in both full training and PEFT training. This change also makes loading the PEFT policy adapter much easier. --- lerobot/common/utils/train_utils.py | 4 +++ lerobot/record.py | 42 +++++++---------------------- tests/utils/test_train_utils.py | 14 ++++++++++ 3 files changed, 27 insertions(+), 33 deletions(-) diff --git a/lerobot/common/utils/train_utils.py b/lerobot/common/utils/train_utils.py index a79983128..6848e7db6 100644 --- a/lerobot/common/utils/train_utils.py +++ b/lerobot/common/utils/train_utils.py @@ -99,6 +99,10 @@ def save_checkpoint( pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) + if cfg.use_peft: + # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the + # policy config which we need for loading the model. In this case we'll write it ourselves. 
+ policy.config.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler) diff --git a/lerobot/record.py b/lerobot/record.py index 4f3c32c3a..256af2f37 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -44,10 +44,6 @@ from pprint import pformat import numpy as np import rerun as rr -from peft import PeftConfig, PeftModel -import importlib - - from lerobot.common.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 ) @@ -130,21 +126,6 @@ class DatasetRecordConfig: raise ValueError("You need to provide a task as argument in `single_task`.") -def get_policy_config_from_peft_checkpoint(peft_config): - if getattr(peft_config, "auto_mapping", None) is None: - raise ValueError( - "No auto-mapping config found in adapter config. Cannot determine policy config." - ) - - auto_mapping = getattr(peft_config, "auto_mapping", None) - base_model_class = auto_mapping["base_model_class"] - parent_library_name = auto_mapping["parent_library"] - - parent_library = importlib.import_module(parent_library_name) - target_class = getattr(parent_library, base_model_class) - return target_class.config_class - - @dataclass class RecordConfig: robot: RobotConfig @@ -167,23 +148,15 @@ class RecordConfig: if policy_path: cli_overrides = parser.get_cli_overrides("policy") + # In case of a PEFT model We assume that the user saved the policy config (`config.json`) alongside the + # adapter parameters / config. If they didn't we could instantiate the default configuration for the policy + # but we wouldn't know if that is correct. So, in case of a missing config this will simply fail. + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if (Path(policy_path) / 'adapter_config.json').exists(): - # The pretrained checkpoint is a PEFT adapter, cool. 
Currently we don't upload the - # policy's config alongside the adapter config but to initialize the policy we - # need a policy config. We assume that the config hasn't changed and we infer - # the policy's config class from the base class mentioned in the adapter config. - self.peft_config = PeftConfig.from_pretrained(policy_path) - - policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config) - - self.policy = policy_config_class() - self.policy.pretrained_path = policy_path self.policy.use_peft = True - else: - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - if self.teleop is None and self.policy is None: raise ValueError("Choose a policy, a teleoperator or both to control the robot") @@ -314,6 +287,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy if cfg.policy.use_peft: + from peft import PeftModel + + logging.info("Loading policy's PEFT adapter.") # in case of PEFT we re-use the policy pretrained path to point to the adapter path. 
peft_path = cfg.policy.pretrained_path cfg.policy.pretrained_path = None diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index b78f6e497..198c5c4ea 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -79,6 +79,20 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): mock_save_training_state.assert_called_once() +@patch("lerobot.common.utils.train_utils.save_training_state") +def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer): + policy = Mock() + policy.config = Mock() + policy.config.save_pretrained = Mock() + cfg = Mock() + cfg.use_peft = True + save_checkpoint(tmp_path, 10, cfg, policy, optimizer) + policy.save_pretrained.assert_called_once() + cfg.save_pretrained.assert_called_once() + policy.config.save_pretrained.assert_called_once() + mock_save_training_state.assert_called_once() + + def test_save_training_state(tmp_path, optimizer, scheduler): save_training_state(tmp_path, 10, optimizer, scheduler) assert (tmp_path / TRAINING_STATE_DIR).is_dir()