Store policy config alongside PEFT checkpoint

Before this change, saving a PEFT-wrapped policy wrote only the adapter
config / weights and not the policy's own config, which prevented us from
changing the policy config. Now the policy config is saved in both full
training and PEFT training.

This change makes loading the PEFT policy adapter much easier as well.
This commit is contained in:
nemo
2025-06-22 19:46:10 +02:00
parent 7fd8b4c773
commit dc67b2ff3f
3 changed files with 27 additions and 33 deletions
+4
View File
@@ -99,6 +99,10 @@ def save_checkpoint(
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir) policy.save_pretrained(pretrained_dir)
cfg.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir)
if cfg.use_peft:
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
# policy config which we need for loading the model. In this case we'll write it ourselves.
policy.config.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler) save_training_state(checkpoint_dir, step, optimizer, scheduler)
+9 -33
View File
@@ -44,10 +44,6 @@ from pprint import pformat
import numpy as np import numpy as np
import rerun as rr import rerun as rr
from peft import PeftConfig, PeftModel
import importlib
from lerobot.common.cameras import ( # noqa: F401 from lerobot.common.cameras import ( # noqa: F401
CameraConfig, # noqa: F401 CameraConfig, # noqa: F401
) )
@@ -130,21 +126,6 @@ class DatasetRecordConfig:
raise ValueError("You need to provide a task as argument in `single_task`.") raise ValueError("You need to provide a task as argument in `single_task`.")
def get_policy_config_from_peft_checkpoint(peft_config):
    """Return the policy config class referenced by a PEFT adapter config.

    The adapter config's ``auto_mapping`` entry names the base model class
    (``base_model_class``) and the library that defines it
    (``parent_library``). The library is imported, the class is looked up on
    it, and that class's ``config_class`` attribute is returned.

    Args:
        peft_config: Adapter config object; only its ``auto_mapping``
            attribute is read.

    Returns:
        The ``config_class`` attribute of the resolved base model class.

    Raises:
        ValueError: If ``peft_config`` has no ``auto_mapping`` attribute or
            the attribute is ``None``.
    """
    # Read the attribute once; a missing attribute and an explicit None are
    # treated the same way.
    auto_mapping = getattr(peft_config, "auto_mapping", None)
    if auto_mapping is None:
        raise ValueError(
            "No auto-mapping config found in adapter config. Cannot determine policy config."
        )

    base_model_class = auto_mapping["base_model_class"]
    parent_library = importlib.import_module(auto_mapping["parent_library"])
    return getattr(parent_library, base_model_class).config_class
@dataclass @dataclass
class RecordConfig: class RecordConfig:
robot: RobotConfig robot: RobotConfig
@@ -167,23 +148,15 @@ class RecordConfig:
if policy_path: if policy_path:
cli_overrides = parser.get_cli_overrides("policy") cli_overrides = parser.get_cli_overrides("policy")
# In case of a PEFT model, we assume that the user saved the policy config (`config.json`) alongside the
# adapter parameters / config. If they didn't, we could instantiate the default configuration for the policy,
# but we wouldn't know whether that is correct. So, if the config is missing, this will simply fail.
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if (Path(policy_path) / 'adapter_config.json').exists(): if (Path(policy_path) / 'adapter_config.json').exists():
# The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the
# policy's config alongside the adapter config but to initialize the policy we
# need a policy config. We assume that the config hasn't changed and we infer
# the policy's config class from the base class mentioned in the adapter config.
self.peft_config = PeftConfig.from_pretrained(policy_path)
policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config)
self.policy = policy_config_class()
self.policy.pretrained_path = policy_path
self.policy.use_peft = True self.policy.use_peft = True
else:
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
if self.teleop is None and self.policy is None: if self.teleop is None and self.policy is None:
raise ValueError("Choose a policy, a teleoperator or both to control the robot") raise ValueError("Choose a policy, a teleoperator or both to control the robot")
@@ -314,6 +287,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy # Load pretrained policy
if cfg.policy.use_peft: if cfg.policy.use_peft:
from peft import PeftModel
logging.info("Loading policy's PEFT adapter.")
# in case of PEFT we re-use the policy pretrained path to point to the adapter path. # in case of PEFT we re-use the policy pretrained path to point to the adapter path.
peft_path = cfg.policy.pretrained_path peft_path = cfg.policy.pretrained_path
cfg.policy.pretrained_path = None cfg.policy.pretrained_path = None
+14
View File
@@ -79,6 +79,20 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
mock_save_training_state.assert_called_once() mock_save_training_state.assert_called_once()
@patch("lerobot.common.utils.train_utils.save_training_state")
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
    """With ``cfg.use_peft`` set, the policy config is saved next to the adapter."""
    # Mock creates child attributes on first access, so
    # policy.config.save_pretrained needs no explicit setup.
    policy = Mock()
    cfg = Mock()
    cfg.use_peft = True

    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)

    policy.save_pretrained.assert_called_once()
    cfg.save_pretrained.assert_called_once()
    # The policy config must be written to the same directory the adapter was
    # saved to; otherwise loading from the checkpoint directory cannot find it.
    adapter_dir = policy.save_pretrained.call_args.args[0]
    policy.config.save_pretrained.assert_called_once_with(adapter_dir)
    mock_save_training_state.assert_called_once()
def test_save_training_state(tmp_path, optimizer, scheduler): def test_save_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler) save_training_state(tmp_path, 10, optimizer, scheduler)
assert (tmp_path / TRAINING_STATE_DIR).is_dir() assert (tmp_path / TRAINING_STATE_DIR).is_dir()