Add basic support for PEFT adapter methods

This change adds support for training policies with far fewer trainable
parameters by applying adapter methods such as LoRA to specific parts of the
policies, which in turn may allow higher learning rates / batch sizes.

To make this as accessible as possible I thought it useful to provide
defaults for `target_modules` and `modules_to_save`. Currently only SmolVLA
has such defaults, but once we agree that this change is useful I will set
out to generate more of them. While the user can override these
settings, they are expected to only change the `method_type`, `r` (rank) and
`init_type` parameters.
This commit is contained in:
nemo
2025-06-22 13:45:07 +02:00
parent c940676bdd
commit 98856662c1
5 changed files with 141 additions and 4 deletions
+24
View File
@@ -69,3 +69,27 @@ class EvalConfig:
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
)
@dataclass
class PeftConfig:
    """High-level options for wrapping a policy with a PEFT adapter.

    PEFT offers many methods; layer adapters are the most common and currently also the most effective
    ones, so this high-level config interface focuses on those.
    """

    # Modules the adapter is applied to. Can be set by the user; when left as `None` a policy-dependent
    # default is used (see `get_default_peft_configuration` in the training script). If the policy has no
    # default either, wrapping the policy raises a `ValueError`.
    target_modules: list[str] | None = None
    # Similarly to `target_modules` this has policy-dependent defaults which the user can override.
    modules_to_save: list[str] | None = None
    # The PEFT (adapter) method to apply to the policy. The value is upper-cased and looked up by name in
    # `peft.PeftType` when the policy is wrapped.
    method_type: str = "LORA"
    # Adapter initialization method. Look at the specific adapter method documentation for defaults.
    # Forwarded as `init_lora_weights` for LORA and as `init_weights` for BONE; setting it for any other
    # method raises a `ValueError` when the policy is wrapped.
    init_type: str | None = None
    # Adapter rank. We expect that all adapters are in some way doing rank-decomposition. This is not
    # true, there are several methods that don't, but we're focusing on these methods for now.
    r: int = 16
+6
View File
@@ -74,6 +74,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
)
self.use_amp = False
def get(self, name, default=None):
    """Dict-style lookup: return attribute `name`, or `default` when it does not exist."""
    try:
        value = getattr(self, name)
    except AttributeError:
        # Same fallback `getattr(self, name, default)` would take.
        return default
    return value
def __contains__(self, name):
    """Support `name in config`: True when the object has an attribute called `name`."""
    # A unique sentinel distinguishes "attribute missing" from an attribute whose value is None/falsy.
    absent = object()
    return getattr(self, name, absent) is not absent
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
+3 -1
View File
@@ -26,7 +26,7 @@ from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json"
@@ -63,6 +63,8 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
use_peft: bool = False
peft: PeftConfig = field(default_factory=PeftConfig)
def __post_init__(self):
self.checkpoint_path = None
+45 -3
View File
@@ -44,6 +44,10 @@ from pprint import pformat
import numpy as np
import rerun as rr
from peft import PeftConfig, PeftModel
import importlib
from lerobot.common.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
)
@@ -144,10 +148,36 @@ class RecordConfig:
def __post_init__(self):
    """Resolve `self.policy` from the `policy` path given on the command line, if any.

    Two kinds of checkpoints are handled:

    * A PEFT adapter checkpoint (the directory contains `adapter_config.json`): the adapter config is
      loaded into `self.peft_config` and the policy config class is inferred from its `auto_mapping`
      entry. The policy config is then instantiated with its default values.
    * A regular policy checkpoint: the policy config is loaded with `PreTrainedConfig.from_pretrained`,
      applying the CLI overrides.

    In both cases `self.policy.pretrained_path` is set to the given path.

    Raises:
        ValueError: If the adapter config carries no `auto_mapping`, or if neither a policy nor a
            teleoperator is configured.
    """
    # Local import so the path join below works whether the parser hands back a `str` or a `Path`.
    from pathlib import Path

    # HACK: We parse again the cli args here to get the pretrained path if there was one.
    policy_path = parser.get_path_arg("policy")

    if policy_path:
        cli_overrides = parser.get_cli_overrides("policy")

        if (Path(policy_path) / "adapter_config.json").exists():
            # The pretrained checkpoint is a PEFT adapter. Currently we don't upload the policy's
            # config alongside the adapter config but to initialize the policy we need a policy
            # config. We assume that the config hasn't changed and we infer the policy's config
            # class from the base class mentioned in the adapter config.
            self.peft_config = PeftConfig.from_pretrained(policy_path)

            auto_mapping = getattr(self.peft_config, "auto_mapping", None)
            if auto_mapping is None:
                raise ValueError(
                    "No auto-mapping config found in adapter config. Cannot determine policy config."
                )

            parent_library = importlib.import_module(auto_mapping["parent_library"])
            target_class = getattr(parent_library, auto_mapping["base_model_class"])
            policy_config_class = target_class.config_class

            # NOTE(review): `cli_overrides` are not applied in this branch; the policy config is
            # built from its defaults only. Confirm whether that is intended.
            self.policy = policy_config_class()
        else:
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)

        self.policy.pretrained_path = policy_path

    if self.teleop is None and self.policy is None:
        raise ValueError("Choose a policy, a teleoperator or both to control the robot")
@@ -277,7 +307,19 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if cfg.use_peft:
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
peft_path = cfg.policy.pretrained_path
cfg.policy.pretrained_path = None
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
policy = PeftModel.from_pretrained(policy, peft_path)
policy = policy.merge_and_unload()
else:
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
robot.connect()
if teleop is not None:
+63
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import logging
import time
from contextlib import nullcontext
@@ -105,6 +106,64 @@ def update_policy(
return train_metrics, output_dict
def get_default_peft_configuration(policy_type):
    """Return the built-in PEFT defaults for `policy_type` as a fresh dict.

    Only "smolvla" has defaults for `target_modules` and `modules_to_save`. Every other policy type
    gets a dict with `modules_to_save` set to `None` and no `target_modules` key.
    """
    if policy_type != "smolvla":
        return {'modules_to_save': None}

    smolvla_target_pattern = (
        r"(model\.vlm_with_expert\.lm_expert\..*\.(q_proj|v_proj)|model\.action_.*|model\.state_proj.*)"
    )
    # These are inf on load otherwise
    smolvla_saved_modules = ["normalize_inputs", "normalize_targets", "unnormalize_outputs"]

    return {
        "target_modules": smolvla_target_pattern,
        "modules_to_save": smolvla_saved_modules,
    }
def wrap_policy_in_peft_model(cfg, policy):
    """Freeze `policy` and wrap it in a PEFT model built from policy defaults plus user overrides.

    The adapter configuration starts from `get_default_peft_configuration(cfg.policy.type)`. The
    non-`None` values of `target_modules`, `modules_to_save` and `r` from `cfg.peft` override those
    defaults. `init_type`, when set, is forwarded under the keyword the selected method uses.

    Args:
        cfg: Training config; `cfg.policy.type` and `cfg.peft` are read.
        policy: The policy to adapt. All of its parameters get `requires_grad=False` before wrapping.

    Returns:
        The result of `peft.get_peft_model` for the assembled configuration.

    Raises:
        ValueError: If the method type is not a member of `peft.PeftType`, if no `target_modules` is
            available for the policy, or if `init_type` is set for a method other than LORA or BONE.
    """
    from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model

    # Disable all gradients because we'll only train the parameters selected by the PEFT method.
    # Layers that should receive gradients anyway need to be listed in `modules_to_save`.
    for param in policy.parameters():
        param.requires_grad_(False)

    peft_config_policy = get_default_peft_configuration(cfg.policy.type)
    # When no PEFT section is configured this is empty, so every lookup below must tolerate
    # missing keys instead of indexing directly.
    peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}

    method_name = str(peft_config_cli.get("method_type", "LORA")).upper()
    try:
        peft_method_type = PeftType[method_name]
    except KeyError:
        valid_methods = ", ".join(member.name for member in PeftType)
        raise ValueError(
            f"Unknown PEFT method {method_name!r}. Valid methods are: {valid_methods}."
        ) from None
    peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]

    # Handle specific CLI overrides
    for key in ("target_modules", "modules_to_save", "r"):
        override = peft_config_cli.get(key)
        if override is not None:
            peft_config_policy[key] = override

    if "target_modules" not in peft_config_policy:
        raise ValueError(
            f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
        )

    # Init method depends on the used PEFT method, your specific PEFT method
    # might not be considered here, in that case an error is raised.
    init_type = peft_config_cli.get("init_type")
    if init_type is not None:
        # Keyed by the enum's string value so the lookup does not depend on enum/str equality.
        init_kwarg_by_method = {"LORA": "init_lora_weights", "BONE": "init_weights"}
        init_kwarg = init_kwarg_by_method.get(peft_method_type.value)
        if init_kwarg is None:
            raise ValueError(f"Init type {init_type} unknown for PEFT method {peft_method_type}.")
        peft_config_policy[init_kwarg] = init_type

    return get_peft_model(policy, peft_config_cls(**peft_config_policy))
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
@@ -141,6 +200,10 @@ def train(cfg: TrainPipelineConfig):
ds_meta=dataset.meta,
)
if cfg.use_peft:
logging.info("Using PEFT! Wrapping model.")
policy = wrap_policy_in_peft_model(cfg, policy)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)