diff --git a/lerobot/common/utils/train_utils.py b/lerobot/common/utils/train_utils.py index a79983128..6848e7db6 100644 --- a/lerobot/common/utils/train_utils.py +++ b/lerobot/common/utils/train_utils.py @@ -99,6 +99,10 @@ def save_checkpoint( pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) + if cfg.use_peft: + # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the + # policy config which we need for loading the model. In this case we'll write it ourselves. + policy.config.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler) diff --git a/lerobot/record.py b/lerobot/record.py index 4f3c32c3a..256af2f37 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -44,10 +44,6 @@ from pprint import pformat import numpy as np import rerun as rr -from peft import PeftConfig, PeftModel -import importlib - - from lerobot.common.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 ) @@ -130,21 +126,6 @@ class DatasetRecordConfig: raise ValueError("You need to provide a task as argument in `single_task`.") -def get_policy_config_from_peft_checkpoint(peft_config): - if getattr(peft_config, "auto_mapping", None) is None: - raise ValueError( - "No auto-mapping config found in adapter config. Cannot determine policy config." - ) - - auto_mapping = getattr(peft_config, "auto_mapping", None) - base_model_class = auto_mapping["base_model_class"] - parent_library_name = auto_mapping["parent_library"] - - parent_library = importlib.import_module(parent_library_name) - target_class = getattr(parent_library, base_model_class) - return target_class.config_class - - @dataclass class RecordConfig: robot: RobotConfig @@ -167,23 +148,15 @@ class RecordConfig: if policy_path: cli_overrides = parser.get_cli_overrides("policy") + # In case of a PEFT model We assume that the user saved the policy config (`config.json`) alongside the + # adapter parameters / config. If they didn't we could instantiate the default configuration for the policy + # but we wouldn't know if that is correct. So, in case of a missing config this will simply fail. + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + if (Path(policy_path) / 'adapter_config.json').exists(): - # The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the - # policy's config alongside the adapter config but to initialize the policy we - # need a policy config. We assume that the config hasn't changed and we infer - # the policy's config class from the base class mentioned in the adapter config. - self.peft_config = PeftConfig.from_pretrained(policy_path) - - policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config) - - self.policy = policy_config_class() - self.policy.pretrained_path = policy_path self.policy.use_peft = True - else: - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - if self.teleop is None and self.policy is None: raise ValueError("Choose a policy, a teleoperator or both to control the robot") @@ -314,6 +287,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy if cfg.policy.use_peft: + from peft import PeftModel + + logging.info("Loading policy's PEFT adapter.") # in case of PEFT we re-use the policy pretrained path to point to the adapter path. peft_path = cfg.policy.pretrained_path cfg.policy.pretrained_path = None diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index b78f6e497..198c5c4ea 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -79,6 +79,20 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): mock_save_training_state.assert_called_once() +@patch("lerobot.common.utils.train_utils.save_training_state") +def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer): + policy = Mock() + policy.config = Mock() + policy.config.save_pretrained = Mock() + cfg = Mock() + cfg.use_peft = True + save_checkpoint(tmp_path, 10, cfg, policy, optimizer) + policy.save_pretrained.assert_called_once() + cfg.save_pretrained.assert_called_once() + policy.config.save_pretrained.assert_called_once() + mock_save_training_state.assert_called_once() + + def test_save_training_state(tmp_path, optimizer, scheduler): save_training_state(tmp_path, 10, optimizer, scheduler) assert (tmp_path / TRAINING_STATE_DIR).is_dir()