mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Store policy config alongside PEFT checkpoint
Before this change the PEFT-wrapped policy did not save the policy's config alongside the adapter config / weights which prevented us from changing the policy config. Now the policy config is saved both in full training and PEFT training. This change makes loading the PEFT policy adapter much easier as well.
This commit is contained in:
@@ -99,6 +99,10 @@ def save_checkpoint(
|
|||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir)
|
policy.save_pretrained(pretrained_dir)
|
||||||
cfg.save_pretrained(pretrained_dir)
|
cfg.save_pretrained(pretrained_dir)
|
||||||
|
if cfg.use_peft:
|
||||||
|
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||||
|
# policy config which we need for loading the model. In this case we'll write it ourselves.
|
||||||
|
policy.config.save_pretrained(pretrained_dir)
|
||||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+9
-33
@@ -44,10 +44,6 @@ from pprint import pformat
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import rerun as rr
|
import rerun as rr
|
||||||
|
|
||||||
from peft import PeftConfig, PeftModel
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
|
|
||||||
from lerobot.common.cameras import ( # noqa: F401
|
from lerobot.common.cameras import ( # noqa: F401
|
||||||
CameraConfig, # noqa: F401
|
CameraConfig, # noqa: F401
|
||||||
)
|
)
|
||||||
@@ -130,21 +126,6 @@ class DatasetRecordConfig:
|
|||||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||||
|
|
||||||
|
|
||||||
def get_policy_config_from_peft_checkpoint(peft_config):
|
|
||||||
if getattr(peft_config, "auto_mapping", None) is None:
|
|
||||||
raise ValueError(
|
|
||||||
"No auto-mapping config found in adapter config. Cannot determine policy config."
|
|
||||||
)
|
|
||||||
|
|
||||||
auto_mapping = getattr(peft_config, "auto_mapping", None)
|
|
||||||
base_model_class = auto_mapping["base_model_class"]
|
|
||||||
parent_library_name = auto_mapping["parent_library"]
|
|
||||||
|
|
||||||
parent_library = importlib.import_module(parent_library_name)
|
|
||||||
target_class = getattr(parent_library, base_model_class)
|
|
||||||
return target_class.config_class
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RecordConfig:
|
class RecordConfig:
|
||||||
robot: RobotConfig
|
robot: RobotConfig
|
||||||
@@ -167,23 +148,15 @@ class RecordConfig:
|
|||||||
if policy_path:
|
if policy_path:
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
|
|
||||||
|
# In case of a PEFT model We assume that the user saved the policy config (`config.json`) alongside the
|
||||||
|
# adapter parameters / config. If they didn't we could instantiate the default configuration for the policy
|
||||||
|
# but we wouldn't know if that is correct. So, in case of a missing config this will simply fail.
|
||||||
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
if (Path(policy_path) / 'adapter_config.json').exists():
|
if (Path(policy_path) / 'adapter_config.json').exists():
|
||||||
# The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the
|
|
||||||
# policy's config alongside the adapter config but to initialize the policy we
|
|
||||||
# need a policy config. We assume that the config hasn't changed and we infer
|
|
||||||
# the policy's config class from the base class mentioned in the adapter config.
|
|
||||||
self.peft_config = PeftConfig.from_pretrained(policy_path)
|
|
||||||
|
|
||||||
policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config)
|
|
||||||
|
|
||||||
self.policy = policy_config_class()
|
|
||||||
self.policy.pretrained_path = policy_path
|
|
||||||
self.policy.use_peft = True
|
self.policy.use_peft = True
|
||||||
|
|
||||||
else:
|
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
|
||||||
self.policy.pretrained_path = policy_path
|
|
||||||
|
|
||||||
if self.teleop is None and self.policy is None:
|
if self.teleop is None and self.policy is None:
|
||||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||||
|
|
||||||
@@ -314,6 +287,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
|
|
||||||
if cfg.policy.use_peft:
|
if cfg.policy.use_peft:
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
logging.info("Loading policy's PEFT adapter.")
|
||||||
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
|
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
|
||||||
peft_path = cfg.policy.pretrained_path
|
peft_path = cfg.policy.pretrained_path
|
||||||
cfg.policy.pretrained_path = None
|
cfg.policy.pretrained_path = None
|
||||||
|
|||||||
@@ -79,6 +79,20 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
|||||||
mock_save_training_state.assert_called_once()
|
mock_save_training_state.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("lerobot.common.utils.train_utils.save_training_state")
|
||||||
|
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
|
||||||
|
policy = Mock()
|
||||||
|
policy.config = Mock()
|
||||||
|
policy.config.save_pretrained = Mock()
|
||||||
|
cfg = Mock()
|
||||||
|
cfg.use_peft = True
|
||||||
|
save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
|
||||||
|
policy.save_pretrained.assert_called_once()
|
||||||
|
cfg.save_pretrained.assert_called_once()
|
||||||
|
policy.config.save_pretrained.assert_called_once()
|
||||||
|
mock_save_training_state.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_save_training_state(tmp_path, optimizer, scheduler):
|
def test_save_training_state(tmp_path, optimizer, scheduler):
|
||||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||||
assert (tmp_path / TRAINING_STATE_DIR).is_dir()
|
assert (tmp_path / TRAINING_STATE_DIR).is_dir()
|
||||||
|
|||||||
Reference in New Issue
Block a user