diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py
index ce72466a8..babba4d7a 100644
--- a/lerobot/configs/default.py
+++ b/lerobot/configs/default.py
@@ -69,3 +69,29 @@ class EvalConfig:
                 f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
                 f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
             )
+
+
+@dataclass
+class PeftConfig:
+    """High-level options for parameter-efficient fine-tuning (PEFT) of a policy.
+
+    PEFT offers many methods. Layer adapters are the most common and currently also the most effective ones,
+    so this config focuses on those.
+    """
+
+    # `target_modules` can be set by the user but defaults to policy-dependent values, see
+    # `get_default_peft_configuration` in `scripts/train.py`.
+    target_modules: list[str] | None = None
+
+    # Similarly to `target_modules`, this has policy-dependent defaults which the user can override.
+    modules_to_save: list[str] | None = None
+
+    # The PEFT (adapter) method to apply to the policy: the case-insensitive name of a `peft.PeftType` member.
+    method_type: str = "LORA"
+
+    # Adapter initialization method. Look at the specific adapter method documentation for defaults.
+    init_type: str | None = None
+
+    # Rank of the adapter. This assumes that the adapter performs some form of rank decomposition, which is
+    # not true for every PEFT method, but we are focusing on those methods for now.
+    r: int = 16
diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py
index 1302db1fa..a85ca187f 100644
--- a/lerobot/configs/policies.py
+++ b/lerobot/configs/policies.py
@@ -74,6 +74,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
             )
             self.use_amp = False
 
+    def get(self, name, default=None):
+        """Return attribute `name`, or `default` if it is not set (dict-style access)."""
+        return getattr(self, name, default)
+
+    def __contains__(self, name):
+        """Return whether the config has an attribute called `name` (dict-style `in`)."""
+        return hasattr(self, name)
+
     @property
     def type(self) -> str:
         return self.get_choice_name(self.__class__)
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index 96a460bdf..de59f459a 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -26,7 +26,7 @@ from lerobot.common.optim import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
 from lerobot.configs import parser
-from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
+from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
 from lerobot.configs.policies import PreTrainedConfig
 
 TRAIN_CONFIG_NAME = "train_config.json"
@@ -63,6 +63,8 @@ class TrainPipelineConfig(HubMixin):
     scheduler: LRSchedulerConfig | None = None
     eval: EvalConfig = field(default_factory=EvalConfig)
     wandb: WandBConfig = field(default_factory=WandBConfig)
+    use_peft: bool = False
+    peft: PeftConfig = field(default_factory=PeftConfig)
 
     def __post_init__(self):
         self.checkpoint_path = None
diff --git a/lerobot/record.py b/lerobot/record.py
index acc844ff9..8ae34d76c 100644
--- a/lerobot/record.py
+++ b/lerobot/record.py
@@ -144,10 +144,39 @@ class RecordConfig:
     def __post_init__(self):
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
         policy_path = parser.get_path_arg("policy")
+
         if policy_path:
             cli_overrides = parser.get_cli_overrides("policy")
-            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+
+            # `policy_path` may be a plain string, so wrap it before using the `/` path operator.
+            from pathlib import Path
+
+            if (Path(policy_path) / "adapter_config.json").exists():
+                # Imported lazily so that `peft` is only needed when the checkpoint is an adapter.
+                import importlib
+
+                from peft import PeftConfig
+
+                # The pretrained checkpoint is a PEFT adapter. Currently we don't upload the policy's
+                # config alongside the adapter config, but we need a policy config to initialize the
+                # policy. We assume that the config hasn't changed and infer the policy's config class
+                # from the base model class mentioned in the adapter config.
+                # NOTE: `cli_overrides` are not applied here, the policy config keeps its defaults.
+                self.peft_config = PeftConfig.from_pretrained(policy_path)
+
+                auto_mapping = getattr(self.peft_config, "auto_mapping", None)
+                if auto_mapping is None:
+                    raise ValueError(
+                        "No auto-mapping config found in adapter config. Cannot determine policy config."
+                    )
+
+                parent_library = importlib.import_module(auto_mapping["parent_library"])
+                target_class = getattr(parent_library, auto_mapping["base_model_class"])
+                self.policy = target_class.config_class()
+            else:
+                self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+
             self.policy.pretrained_path = policy_path
 
         if self.teleop is None and self.policy is None:
             raise ValueError("Choose a policy, a teleoperator or both to control the robot")
@@ -277,7 +306,23 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
         )
 
     # Load pretrained policy
-    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
+    policy = None
+    # `RecordConfig.__post_init__` only sets `peft_config` when the checkpoint is a PEFT adapter.
+    is_peft_adapter = getattr(cfg, "peft_config", None) is not None
+    if cfg.policy is not None and is_peft_adapter:
+        # Imported lazily so that `peft` is only needed when the checkpoint is an adapter.
+        from peft import PeftModel
+
+        # In case of PEFT we re-use the policy pretrained path to point to the adapter path, so it is
+        # cleared before building the policy; the adapter is loaded and merged into the policy afterwards.
+        peft_path = cfg.policy.pretrained_path
+        cfg.policy.pretrained_path = None
+
+        policy = make_policy(cfg.policy, ds_meta=dataset.meta)
+        policy = PeftModel.from_pretrained(policy, peft_path)
+        policy = policy.merge_and_unload()
+    elif cfg.policy is not None:
+        policy = make_policy(cfg.policy, ds_meta=dataset.meta)
 
     robot.connect()
     if teleop is not None:
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 0de247be9..b9fe2a8ea 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import dataclasses
 import logging
 import time
 from contextlib import nullcontext
@@ -105,6 +106,83 @@ def update_policy(
     return train_metrics, output_dict
 
 
+def get_default_peft_configuration(policy_type):
+    """Return the policy-dependent default keyword arguments of the PEFT config.
+
+    Policies without dedicated defaults only get `modules_to_save=None`, in which case `target_modules`
+    has to be passed by the user.
+    """
+    if policy_type == "smolvla":
+        return {
+            "target_modules": r"(model\.vlm_with_expert\.lm_expert\..*\.(q_proj|v_proj)|model\.action_.*|model\.state_proj.*)",
+            "modules_to_save": [
+                # These are inf on load otherwise
+                "normalize_inputs",
+                "normalize_targets",
+                "unnormalize_outputs",
+            ],
+        }
+
+    return {"modules_to_save": None}
+
+
+def wrap_policy_in_peft_model(cfg, policy):
+    """Wrap `policy` in a PEFT model built from the policy defaults and `cfg.peft`.
+
+    Values set in `cfg.peft` take precedence over the defaults of `get_default_peft_configuration`.
+
+    Raises:
+        ValueError: If the PEFT method is unknown, if there are no `target_modules` or if `init_type`
+            is set for a PEFT method whose init argument is not handled here.
+    """
+    from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
+
+    # Disable all gradients because we'll only train the parameters selected by the PEFT method.
+    # Layers that should receive gradients anyway need to be listed in `modules_to_save`.
+    for p in policy.parameters():
+        p.requires_grad_(False)
+
+    peft_config_policy = get_default_peft_configuration(cfg.policy.type)
+    peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
+
+    # Read with `.get` so that an unset `cfg.peft` falls back to defaults instead of raising a `KeyError`.
+    method_name = (peft_config_cli.get("method_type") or "LORA").upper()
+    if method_name not in PeftType.__members__:
+        raise ValueError(
+            f"Unknown PEFT method {method_name}. Available methods: {', '.join(PeftType.__members__)}."
+        )
+    peft_method_type = PeftType[method_name]
+    peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
+
+    # Handle specific CLI overrides
+    for key in ["target_modules", "modules_to_save", "r"]:
+        if peft_config_cli.get(key) is not None:
+            peft_config_policy[key] = peft_config_cli[key]
+
+    if "target_modules" not in peft_config_policy:
+        raise ValueError(
+            f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
+        )
+
+    # Init method depends on the used PEFT method, your specific PEFT method
+    # might not be considered here, in that case an error is raised.
+    init_type = peft_config_cli.get("init_type")
+    if init_type is not None:
+        if method_name == "LORA":
+            peft_config_policy["init_lora_weights"] = init_type
+        elif method_name == "BONE":
+            peft_config_policy["init_weights"] = init_type
+        else:
+            raise ValueError(f"Init type {init_type} unknown for PEFT method {peft_method_type}.")
+
+    policy = get_peft_model(
+        policy,
+        peft_config_cls(**peft_config_policy),
+    )
+
+    return policy
+
+
 @parser.wrap()
 def train(cfg: TrainPipelineConfig):
     cfg.validate()
@@ -141,6 +219,10 @@ def train(cfg: TrainPipelineConfig):
         ds_meta=dataset.meta,
     )
 
+    if cfg.use_peft:
+        logging.info("Using PEFT! Wrapping model.")
+        policy = wrap_policy_in_peft_model(cfg, policy)
+
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)