mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Add basic support for PEFT adapter methods
This change adds support for training policies with far fewer trainable parameters by applying adapter methods such as LoRA to specific parts of the policies, which in turn may allow higher learning rates / batch sizes. To make this as accessible as possible, I thought it useful to provide defaults for `target_modules` and `modules_to_save`. Currently only SmolVLA has such defaults, but once we agree that this change is useful I will set out to generate more of them. While the user can override these settings, they are expected to only change the peft_method, rank and init_type parameters.
This commit is contained in:
@@ -69,3 +69,27 @@ class EvalConfig:
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class PeftConfig:
    # High-level options for applying a PEFT adapter method to a policy.
    #
    # PEFT offers many methods; layer adapters are the most common and currently also the most effective ones,
    # so this high-level config interface focuses on those.

    # `target_modules` can be set by the user but defaults to policy-specific values. See
    # `get_default_peft_configuration` in `scripts/train.py`. `None` means "use the policy default".
    target_modules: list[str] | None = None

    # Like `target_modules`, this has policy-dependent defaults which the user can override.
    modules_to_save: list[str] | None = None

    # The PEFT (adapter) method to apply to the policy.
    method_type: str = "LORA"

    # Adapter initialization method. Look at the specific adapter method's documentation for its defaults.
    # `None` leaves the initialization untouched.
    init_type: str | None = None

    # Rank of the adapter. This assumes the chosen adapter performs some form of rank decomposition; that is
    # not true for every PEFT method, but only such methods are targeted for now.
    r: int = 16
|
||||
|
||||
@@ -74,6 +74,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
def get(self, name, default=None):
    """Return the attribute called ``name``, or ``default`` when it is absent.

    Gives the config a ``dict.get``-like accessor on top of plain attribute
    access.
    NOTE(review): presumably added so code that treats the config as a
    mapping (e.g. PEFT) can query it -- confirm against the caller.
    """
    try:
        value = getattr(self, name)
    except AttributeError:
        # Missing attribute: behave like a mapping lookup with a fallback.
        return default
    return value
|
||||
|
||||
def __contains__(self, name):
    """Support ``name in config`` by reporting whether the attribute exists."""
    # A private sentinel distinguishes "attribute missing" from an attribute
    # whose value happens to be None or otherwise falsy.
    missing = object()
    return getattr(self, name, missing) is not missing
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.common.optim import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
@@ -63,6 +63,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
use_peft: bool = False
|
||||
peft: PeftConfig = field(default_factory=PeftConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
self.checkpoint_path = None
|
||||
|
||||
+45
-3
@@ -44,6 +44,10 @@ from pprint import pformat
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from peft import PeftConfig, PeftModel
|
||||
import importlib
|
||||
|
||||
|
||||
from lerobot.common.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -144,10 +148,36 @@ class RecordConfig:
|
||||
def __post_init__(self):
    """Resolve ``self.policy`` from the ``policy`` path CLI argument, if one was given.

    When the path contains an ``adapter_config.json`` the checkpoint is
    treated as a PEFT adapter: the adapter config is loaded into
    ``self.peft_config`` and the policy config class is inferred from the
    adapter's ``auto_mapping`` entry. Otherwise the policy config is loaded
    directly from the path.

    Raises:
        ValueError: If the adapter config carries no ``auto_mapping``, or if
            neither a teleoperator nor a policy is configured.
    """
    # Imported locally so this method does not depend on a module-level import.
    from pathlib import Path

    # HACK: We parse again the cli args here to get the pretrained path if there was one.
    policy_path = parser.get_path_arg("policy")

    if policy_path:
        cli_overrides = parser.get_cli_overrides("policy")

        # Coerce to `Path`: the `/` operator fails with a TypeError when the
        # CLI parser hands back a plain string.
        if (Path(policy_path) / "adapter_config.json").exists():
            # The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the
            # policy's config alongside the adapter config but to initialize the policy we
            # need a policy config. We assume that the config hasn't changed and we infer
            # the policy's config class from the base class mentioned in the adapter config.
            # NOTE(review): `cli_overrides` are not applied on this branch -- confirm that is intended.
            self.peft_config = PeftConfig.from_pretrained(policy_path)

            auto_mapping = getattr(self.peft_config, "auto_mapping", None)
            if auto_mapping is None:
                raise ValueError(
                    "No auto-mapping config found in adapter config. Cannot determine policy config."
                )

            base_model_class = auto_mapping["base_model_class"]
            parent_library_name = auto_mapping["parent_library"]

            # Look up the policy class by name in the library recorded by the adapter.
            parent_library = importlib.import_module(parent_library_name)
            target_class = getattr(parent_library, base_model_class)
            policy_config_class = target_class.config_class

            self.policy = policy_config_class()
            self.policy.pretrained_path = policy_path
        else:
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

    if self.teleop is None and self.policy is None:
        raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||
@@ -277,7 +307,19 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
|
||||
if cfg.use_peft:
|
||||
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
cfg.policy.pretrained_path = None
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
|
||||
policy = PeftModel.from_pretrained(policy, peft_path)
|
||||
policy = policy.merge_and_unload()
|
||||
|
||||
else:
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
@@ -105,6 +106,64 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def get_default_peft_configuration(policy_type):
    """Return the built-in PEFT defaults for the given policy type.

    Only ``"smolvla"`` has defaults for ``target_modules``; every other
    policy type gets a dict without that key, so the caller can detect that
    the user has to supply it. A fresh dict is built on every call.
    """
    if policy_type != "smolvla":
        return {'modules_to_save': None}

    # Regex selecting the modules that receive adapters.
    adapter_targets = r"(model\.vlm_with_expert\.lm_expert\..*\.(q_proj|v_proj)|model\.action_.*|model\.state_proj.*)"

    # These are inf on load otherwise
    fully_saved_modules = [
        "normalize_inputs",
        "normalize_targets",
        "unnormalize_outputs",
    ]

    return {
        "target_modules": adapter_targets,
        "modules_to_save": fully_saved_modules,
    }
|
||||
|
||||
|
||||
def wrap_policy_in_peft_model(cfg, policy):
    """Freeze ``policy`` and wrap it in a PEFT model built from ``cfg``.

    The PEFT configuration starts from the policy-specific defaults returned
    by ``get_default_peft_configuration(cfg.policy.type)``; any of
    ``target_modules``, ``modules_to_save`` and ``r`` that is set on
    ``cfg.peft`` overrides the default.

    The whole configuration is validated before the policy is modified, so a
    failing call leaves the policy's ``requires_grad`` flags as they were.

    Args:
        cfg: Training config; ``cfg.policy.type`` and ``cfg.peft`` are read.
        policy: The policy module to wrap.

    Returns:
        The object returned by ``peft.get_peft_model``.

    Raises:
        ValueError: If the PEFT method name is unknown, if no
            ``target_modules`` are available for the policy, or if an
            ``init_type`` is given for a method with no known init parameter.
    """
    # `peft` is imported inside the function, as in the original code.
    from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model

    peft_config_policy = get_default_peft_configuration(cfg.policy.type)
    # An absent `cfg.peft` means "no overrides"; every lookup below tolerates the empty dict.
    peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}

    method_name = (peft_config_cli.get("method_type") or "LORA").upper()
    try:
        peft_method_type = PeftType[method_name]
    except KeyError:
        valid_methods = ", ".join(member.name for member in PeftType)
        raise ValueError(
            f"Unknown PEFT method '{method_name}'. Available methods: {valid_methods}."
        ) from None
    peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]

    # Handle specific CLI overrides
    for key in ("target_modules", "modules_to_save", "r"):
        override = peft_config_cli.get(key)
        if override is not None:
            peft_config_policy[key] = override

    if "target_modules" not in peft_config_policy:
        raise ValueError(
            f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
        )

    # Init method depends on the used PEFT method, your specific PEFT method
    # might not be considered here, in that case an error is raised.
    init_type = peft_config_cli.get("init_type")
    if init_type is not None:
        # NOTE(review): the string comparisons rely on `PeftType` members
        # comparing equal to their upper-case names -- confirm against PEFT.
        if peft_method_type == "LORA":
            peft_config_policy["init_lora_weights"] = init_type
        elif peft_method_type == "BONE":
            peft_config_policy["init_weights"] = init_type
        else:
            raise ValueError(f"Init type {init_type} unknown for PEFT method {peft_method_type}.")

    peft_config = peft_config_cls(**peft_config_policy)

    # Disable all gradients because we'll only train the parameters selected by the PEFT method.
    # Layers that should receive gradients anyway need to be listed in `modules_to_save`.
    # This runs only after validation so a rejected configuration does not freeze the policy.
    for param in policy.parameters():
        param.requires_grad_(False)

    return get_peft_model(policy, peft_config)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
@@ -141,6 +200,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
if cfg.use_peft:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
policy = wrap_policy_in_peft_model(cfg, policy)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||
|
||||
Reference in New Issue
Block a user