diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index b59d15734..381e95dc4 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -19,6 +19,8 @@
       title: Train RL in Simulation
     - local: multi_gpu_training
       title: Multi GPU training
+    - local: peft_training
+      title: Training with PEFT (e.g., LoRA)
   title: "Tutorials"
 - sections:
   - local: lerobot-dataset-v3
diff --git a/docs/source/peft_training.mdx b/docs/source/peft_training.mdx
new file mode 100644
index 000000000..dd0b10075
--- /dev/null
+++ b/docs/source/peft_training.mdx
@@ -0,0 +1,62 @@
+# Parameter-efficient fine-tuning with 🤗 PEFT
+
+[🤗 PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting
+large pretrained models, such as pre-trained policies (e.g., SmolVLA, π₀, ...), to new tasks without training all
+of the model's parameters, while yielding comparable performance.
+
+Install the `lerobot[peft]` optional package to enable PEFT support.
+
+To read about all the available adaptation methods, please refer to the [🤗 PEFT docs](https://huggingface.co/docs/peft/index).
+
+## Training SmolVLA
+
+In this section we'll show you how to fine-tune a pre-trained SmolVLA policy with PEFT on the Libero dataset.
+For brevity we're only training on the `libero_spatial` subset. We will use `lerobot/smolvla_base` as the model
+to parameter-efficiently fine-tune:
+
+```
+lerobot-train \
+    --policy.path=lerobot/smolvla_base \
+    --policy.repo_id=your_hub_name/my_libero_smolvla \
+    --dataset.repo_id=HuggingFaceVLA/libero \
+    --policy.output_features=null \
+    --policy.input_features=null \
+    --policy.optimizer_lr=1e-3 \
+    --policy.scheduler_decay_lr=1e-4 \
+    --env.type=libero \
+    --env.task=libero_spatial \
+    --steps=100000 \
+    --batch_size=32 \
+    --peft.method_type=LORA \
+    --peft.r=64
+```
+
+Note the `--peft.method_type` parameter that lets you select which PEFT method to use. Here we use
+[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adaptation), which is probably the
+most popular fine-tuning method to date. Low-rank adaptation means that we only fine-tune a matrix of comparably low
+rank instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter: the higher the
+rank, the closer you get to full fine-tuning.
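+
+To build intuition for what low-rank adaptation does, here is a minimal, self-contained sketch in plain PyTorch. It is
+not LeRobot or PEFT code (PEFT handles all of this for you), and the layer sizes are made up for illustration:
+
+```python
+import torch
+from torch import nn
+
+
+class LoRALinear(nn.Module):
+    """A frozen linear layer plus a trainable low-rank update: y = W x + (alpha / r) * B A x."""
+
+    def __init__(self, base: nn.Linear, r: int = 16, alpha: float = 16.0):
+        super().__init__()
+        self.base = base
+        for p in self.base.parameters():
+            p.requires_grad_(False)  # the pretrained weights stay frozen
+        self.lora_a = nn.Linear(base.in_features, r, bias=False)  # down-projection A
+        self.lora_b = nn.Linear(r, base.out_features, bias=False)  # up-projection B
+        nn.init.zeros_(self.lora_b.weight)  # B starts at zero, so training starts from the pretrained model
+        self.scaling = alpha / r
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
+
+
+layer = LoRALinear(nn.Linear(960, 960), r=64)
+trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
+print(trainable)  # 2 * 64 * 960 = 122880, versus 960 * 960 = 921600 for the full weight matrix
+```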
+
+There are more complex methods that take additional parameters. These are not yet supported; feel free to raise an
+issue if you want to see a specific PEFT method supported.
+
+By default, PEFT will target the `q_proj` and `v_proj` layers of the LM expert in SmolVLA. It will also target the
+state and action projection matrices, as they are most likely task-dependent. If you need to target different layers,
+you can use `--peft.target_modules` to specify which layers to target. You can refer to the respective PEFT method's
+documentation to see what inputs are supported (e.g., [LoRA's target_modules documentation](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules)).
+Usually a list of suffixes or a regular expression is supported. For example, to target the MLPs of the `lm_expert`
+instead of the `q` and `v` projections, use:
+
+```
+--peft.target_modules='(model\.vlm_with_expert\.lm_expert\..*\.(down|gate|up)_proj|.*\.(state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out))'
+```
+
+In case you need to fully fine-tune a layer instead of just adapting it, you can supply a list of layer suffixes
+to the `--peft.full_training_modules` parameter:
+
+```
+--peft.full_training_modules=["state_proj"]
+```
+
+The learning rate and the scheduled target learning rate can usually be scaled up by a factor of 10 compared to the
+values used for full fine-tuning (e.g., 1e-4 for full fine-tuning becomes 1e-3 with LoRA).
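+
+Once training has finished, you can keep the adapter separate (LeRobot's loading code applies it on top of the base
+policy automatically) or merge it into the base weights for deployment. The following is a rough sketch, assuming a
+SmolVLA policy (the policy class import may differ for other policies); the checkpoint path is a placeholder:
+
+```python
+from peft import PeftConfig, PeftModel
+
+from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
+
+adapter_path = "outputs/train/my_libero_smolvla/checkpoints/last/pretrained_model"  # placeholder
+peft_config = PeftConfig.from_pretrained(adapter_path)
+
+# Load the base policy the adapter was trained on, then apply the adapter weights on top.
+policy = SmolVLAPolicy.from_pretrained(peft_config.base_model_name_or_path)
+policy = PeftModel.from_pretrained(policy, adapter_path, config=peft_config)
+
+# Optionally fold the low-rank updates into the base weights so that inference no longer needs PEFT.
+merged_policy = policy.merge_and_unload()
+```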
diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py
index 6f051485a..5f44649da 100644
--- a/examples/rtc/eval_with_real_robot.py
+++ b/examples/rtc/eval_with_real_robot.py
@@ -455,7 +455,18 @@ def demo_cli(cfg: RTCDemoConfig):
     if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
         config.compile_model = cfg.use_torch_compile
 
-    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
+    if config.use_peft:
+        from peft import PeftConfig, PeftModel
+
+        peft_pretrained_path = cfg.policy.pretrained_path
+        peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
+
+        policy = policy_class.from_pretrained(
+            pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
+        )
+        policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
+    else:
+        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
 
     # Turn on RTC
     policy.config.rtc_config = cfg.rtc
diff --git a/pyproject.toml b/pyproject.toml
index 33586c5e4..980844149 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -146,6 +146,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
 
 # Features
 async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
+peft = ["lerobot[transformers-dep]", "peft>=0.18.0"]
 
 # Development
 dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -182,7 +183,8 @@ all = [
     "lerobot[phone]",
     "lerobot[libero]",
     "lerobot[metaworld]",
-    "lerobot[sarm]"
+    "lerobot[sarm]",
+    "lerobot[peft]",
 ]
 
 [project.scripts]
@@ -417,6 +419,10 @@ conflicts = [
         { extra = "wallx" },
         { extra = "libero" },
     ],
+    [
+        { extra = "wallx" },
+        { extra = "peft" },
+    ],
     [
         { extra = "wallx" },
         { extra = "all" },
@@ -450,6 +456,10 @@ conflicts = [
         { extra = "pi" },
         { extra = "libero" },
     ],
+    [
+        { extra = "pi" },
+        { extra = "peft" },
+    ],
     [
         { extra = "pi" },
         { extra = "all" },
diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index 630d63f1b..f613b5251 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -67,3 +67,31 @@ class EvalConfig:
             f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
             f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
         )
+
+
+@dataclass
+class PeftConfig:
+    # PEFT offers many fine-tuning methods, layer adapters being the most common and currently also the most
+    # effective, so we focus on those in this high-level config interface.
+
+    # Either a string (module name suffix or 'all-linear'), a list of module name suffixes, or a regular expression
+    # describing module names to target with the configured PEFT method. Some policies have a default value for this,
+    # so you don't *have* to choose which layers to adapt, but it might still be worthwhile depending on your case.
+    target_modules: list[str] | str | None = None
+
+    # Names/suffixes of modules to fully fine-tune and store alongside the adapter weights. Useful for layers that are
+    # not part of a pre-trained model (e.g., action and state projections). Depending on the policy, this defaults to
+    # layers that are newly created in pre-trained policies. If you're fine-tuning an already trained policy, you
+    # might want to set this to `[]`. Corresponds to PEFT's `modules_to_save`.
+    full_training_modules: list[str] | None = None
+
+    # The PEFT (adapter) method to apply to the policy. Needs to be a valid PEFT type.
+    method_type: str = "LORA"
+
+    # Adapter initialization method. See the specific PEFT adapter's documentation for defaults.
+    init_type: str | None = None
+
+    # We expect that all PEFT adapters perform some form of rank decomposition, so this parameter specifies
+    # the rank used by the adapter. In general, a higher rank means more trainable parameters and gets you
+    # closer to full fine-tuning.
+    r: int = 16
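At training time this high-level config gets translated into a concrete PEFT config (see `wrap_policy_in_peft_model` in `lerobot_train.py` below). A rough sketch of the mapping for LoRA, with illustrative values; only `full_training_modules` is renamed:

```python
from peft import LoraConfig

from lerobot.configs.default import PeftConfig

cfg = PeftConfig(target_modules="all-linear", full_training_modules=["state_proj"], r=64)

# LeRobot's `full_training_modules` corresponds to PEFT's `modules_to_save`.
lora_config = LoraConfig(
    r=cfg.r,
    target_modules=cfg.target_modules,
    modules_to_save=cfg.full_training_modules,
)
```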
diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index 0ecfa169b..7f326b70b 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -55,14 +55,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):  # type: igno
 
     n_obs_steps: int = 1
 
-    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
-    output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+    # `input_features` can be set to None/null in order to infer those values from the dataset.
+    input_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
+    output_features: dict[str, PolicyFeature] | None = field(default_factory=dict)
 
     device: str | None = None  # e.g. "cuda", "cuda:0", "cpu", or "mps"
 
     # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
     # automatic gradient scaling is used.
     use_amp: bool = False
 
+    # Whether the policy employed PEFT for training.
+    use_peft: bool = False
+
     push_to_hub: bool = True  # type: ignore[assignment] # TODO: use a different name to avoid override
     repo_id: str | None = None
 
@@ -125,6 +129,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):  # type: igno
     @property
     def robot_state_feature(self) -> PolicyFeature | None:
+        if not self.input_features:
+            return None
         for ft_name, ft in self.input_features.items():
             if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
                 return ft
@@ -132,6 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):  # type: igno
     @property
     def env_state_feature(self) -> PolicyFeature | None:
+        if not self.input_features:
+            return None
         for _, ft in self.input_features.items():
             if ft.type is FeatureType.ENV:
                 return ft
@@ -139,10 +147,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):  # type: igno
     @property
     def image_features(self) -> dict[str, PolicyFeature]:
+        if not self.input_features:
+            return {}
         return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
 
     @property
     def action_feature(self) -> PolicyFeature | None:
+        if not self.output_features:
+            return None
         for ft_name, ft in self.output_features.items():
             if ft.type is FeatureType.ACTION and ft_name == ACTION:
                 return ft
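With these guards, a config whose features will be inferred from the dataset later (e.g., `--policy.input_features=null` as in the docs above) degrades gracefully instead of raising. A standalone illustration of the pattern, not LeRobot code:

```python
from dataclasses import dataclass, field


@dataclass
class FeatureConfig:
    input_features: dict[str, str] | None = field(default_factory=dict)

    @property
    def image_features(self) -> dict[str, str]:
        if not self.input_features:  # covers both None and {}
            return {}
        return {key: ft for key, ft in self.input_features.items() if ft == "visual"}


print(FeatureConfig(input_features=None).image_features)  # {}
print(FeatureConfig(input_features={"cam": "visual", "state": "state"}).image_features)  # {'cam': 'visual'}
```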
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index cee9dfdf9..7a5eee77d 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
 
 from lerobot import envs
 from lerobot.configs import parser
-from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
+from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.optim import OptimizerConfig
 from lerobot.optim.schedulers import LRSchedulerConfig
@@ -65,6 +65,7 @@ class TrainPipelineConfig(HubMixin):
     scheduler: LRSchedulerConfig | None = None
     eval: EvalConfig = field(default_factory=EvalConfig)
     wandb: WandBConfig = field(default_factory=WandBConfig)
+    peft: PeftConfig | None = None
 
     # RA-BC (Reward-Aligned Behavior Cloning) parameters
     use_rabc: bool = False  # Enable reward-weighted training
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index 3e24656fc..8c414f235 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -471,11 +471,40 @@ def make_policy(
     if ds_meta is not None:
         kwargs["dataset_meta"] = ds_meta
 
-    if cfg.pretrained_path:
+    if not cfg.pretrained_path and cfg.use_peft:
+        raise ValueError(
+            "Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
+            "the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that."
+        )
+
+    if cfg.pretrained_path and not cfg.use_peft:
         # Load a pretrained policy and override the config if needed (for example, if there are inference-time
         # hyperparameters that we want to vary).
         kwargs["pretrained_name_or_path"] = cfg.pretrained_path
         policy = policy_cls.from_pretrained(**kwargs)
+    elif cfg.pretrained_path and cfg.use_peft:
+        # Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
+        # of the adapter, and the adapter's config contains the path to the base policy. So we need the
+        # adapter config first, then load the correct policy, and then apply PEFT.
+        from peft import PeftConfig, PeftModel
+
+        logging.info("Loading policy's PEFT adapter.")
+
+        peft_pretrained_path = cfg.pretrained_path
+        peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
+
+        kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
+        if not kwargs["pretrained_name_or_path"]:
+            # This means that there's a bug or we trained a policy from scratch using PEFT.
+            # It is more likely that this is a bug, so we'll raise an error.
+            raise ValueError(
+                "No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which "
+                "the adapter was trained."
+            )
+
+        policy = policy_cls.from_pretrained(**kwargs)
+        policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
     else:
         # Make a fresh policy.
         policy = policy_cls(**kwargs)
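The piece that makes this two-step loading work is PEFT's adapter config, which records the base model next to the adapter weights. A sketch of inspecting it directly; the checkpoint path is a placeholder:

```python
import json
from pathlib import Path

adapter_dir = Path("outputs/train/my_run/checkpoints/last/pretrained_model")  # placeholder
adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())

# `make_policy` uses this entry to instantiate the base policy before applying the adapter on top.
print(adapter_config["base_model_name_or_path"])  # e.g., "lerobot/smolvla_base"
```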
" + "Skipping model artifact upload to WandB." + ) + return + self._wandb.log_artifact(artifact) def log_dict( diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 9c40a1883..7b45e88e1 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -278,9 +278,16 @@ def eval_policy( raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") if not isinstance(policy, PreTrainedPolicy): - raise ValueError( + exc = ValueError( f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided." ) + try: + from peft import PeftModel + + if not isinstance(policy, PeftModel): + raise exc + except ImportError: + raise exc from None start = time.time() policy.eval() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 948e92bb8..eafd32ace 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -193,8 +193,10 @@ class RecordConfig: def __post_init__(self): # HACK: We parse again the cli args here to get the pretrained path if there was one. policy_path = parser.get_path_arg("policy") + if policy_path: cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 6cf733442..286c69906 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import logging import time from contextlib import nullcontext @@ -147,6 +148,92 @@ def update_policy( return train_metrics, output_dict +def get_default_peft_configuration(policy_type): + """Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint.""" + + common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + + if policy_type == "smolvla": + return { + "target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))", + "modules_to_save": [], + } + elif policy_type in ("pi0", "pi05"): + return { + "target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))", + "modules_to_save": [], + } + + return {"modules_to_save": None} + + +def wrap_policy_in_peft_model(cfg, policy): + from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model + + # Disable all gradients because we'll only train the parameters selected by the PEFT method. + # Layers that should receive gradients anyway need to be listed in `modules_to_save`. + for p in policy.parameters(): + p.requires_grad_(False) + + if not cfg.policy.pretrained_path: + raise ValueError( + "Training from scratch using PEFT. This is unlikely to yield good results. " + "Supply a `policy.path` to fine-tune an existing model." + ) + + if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights: + logging.warning( + "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set " + "`load_vlm_weights=True` to fine-tune the existing policy." 
diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py
index 948e92bb8..eafd32ace 100644
--- a/src/lerobot/scripts/lerobot_record.py
+++ b/src/lerobot/scripts/lerobot_record.py
@@ -193,8 +193,10 @@ class RecordConfig:
     def __post_init__(self):
         # HACK: We parse again the cli args here to get the pretrained path if there was one.
         policy_path = parser.get_path_arg("policy")
+
        if policy_path:
             cli_overrides = parser.get_cli_overrides("policy")
+
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path
 
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 6cf733442..286c69906 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import dataclasses
 import logging
 import time
 from contextlib import nullcontext
@@ -147,6 +148,92 @@ def update_policy(
     return train_metrics, output_dict
 
 
+def get_default_peft_configuration(policy_type):
+    """Build a basic PEFT configuration for the given policy type, assuming that we train a policy from a checkpoint."""
+
+    common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
+
+    if policy_type == "smolvla":
+        return {
+            "target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))",
+            "modules_to_save": [],
+        }
+    elif policy_type in ("pi0", "pi05"):
+        return {
+            "target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))",
+            "modules_to_save": [],
+        }
+
+    return {"modules_to_save": None}
+
+
+def wrap_policy_in_peft_model(cfg, policy):
+    from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
+
+    # Disable all gradients because we'll only train the parameters selected by the PEFT method.
+    # Layers that should receive gradients anyway need to be listed in `modules_to_save`.
+    for p in policy.parameters():
+        p.requires_grad_(False)
+
+    if not cfg.policy.pretrained_path:
+        raise ValueError(
+            "Training from scratch with PEFT is unlikely to yield good results. "
+            "Supply a `policy.path` to fine-tune an existing model."
+        )
+
+    if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
+        logging.warning(
+            "Training SmolVLA without loading VLM weights while using PEFT. This is unlikely to yield good results. "
+            "Set `load_vlm_weights=True` to fine-tune the existing policy."
+        )
+
+    peft_config_policy = get_default_peft_configuration(cfg.policy.type)
+    peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
+    peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"]  # compatibility with PEFT
+    peft_method_type = PeftType[peft_config_cli["method_type"].upper()]
+    peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
+
+    # Handle specific CLI overrides
+    for key in ["target_modules", "modules_to_save", "r"]:
+        if peft_config_cli[key] is not None:
+            peft_config_policy[key] = peft_config_cli[key]
+
+    if "target_modules" not in peft_config_policy:
+        raise ValueError(
+            f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
+        )
+
+    # The init method depends on the PEFT method used. Your specific PEFT method
+    # might not be handled here; in that case an error is raised.
+    if peft_config_cli["init_type"] is not None:
+        if peft_method_type == "LORA":
+            peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
+        elif peft_method_type == "MISS":
+            peft_config_policy["init_weights"] = peft_config_cli["init_type"]
+        else:
+            raise ValueError(
+                f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
+            )
+
+    # PEFT uses this attribute to set `adapter_config.base_model_name_or_path`, which we use for loading the
+    # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
+    # adapter, not the base model.
+    if policy.config.pretrained_path:
+        policy.name_or_path = str(policy.config.pretrained_path)
+
+    # Finally wrap the policy in a PEFT model
+    policy = get_peft_model(
+        policy,
+        peft_config_cls(**peft_config_policy),
+    )
+
+    # Make sure that the config is tagged as using PEFT so that the loading code can take the
+    # appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
+    policy.config.use_peft = True
+
+    return policy
+
+
 @parser.wrap()
 def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     """
@@ -230,6 +317,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         rename_map=cfg.rename_map,
     )
 
+    if cfg.peft is not None:
+        logging.info("Using PEFT! Wrapping model.")
+        policy = wrap_policy_in_peft_model(cfg, policy)
+
     # Wait for all processes to finish policy creation before continuing
     accelerator.wait_for_everyone()
 
@@ -502,7 +593,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
 
         if cfg.policy.push_to_hub:
             unwrapped_policy = accelerator.unwrap_model(policy)
-            unwrapped_policy.push_model_to_hub(cfg)
+            if cfg.policy.use_peft:
+                unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
+            else:
+                unwrapped_policy.push_model_to_hub(cfg)
 
             preprocessor.push_to_hub(cfg.policy.repo_id)
             postprocessor.push_to_hub(cfg.policy.repo_id)
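A sanity check for the defaults above: PEFT matches a regex `target_modules` against full module names with `re.fullmatch`, so you can verify what the SmolVLA default from `get_default_peft_configuration` will hit. The module names below are hypothetical examples for illustration:

```python
import re

common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
pattern = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"

# Hypothetical module names, for illustration only:
print(bool(re.fullmatch(pattern, "model.vlm_with_expert.lm_expert.layers.0.self_attn.q_proj")))  # True
print(bool(re.fullmatch(pattern, "model.state_proj")))                                           # True
print(bool(re.fullmatch(pattern, "model.vlm_with_expert.lm_expert.layers.0.mlp.up_proj")))       # False
```

And once `get_peft_model` has wrapped the policy, `policy.print_trainable_parameters()` prints a one-line summary of trainable versus total parameters, which is a quick way to confirm that only the adapter (and any `modules_to_save`) will be trained.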
diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py
index 3ebe31971..d8481f4b9 100644
--- a/src/lerobot/utils/train_utils.py
+++ b/src/lerobot/utils/train_utils.py
@@ -99,6 +99,10 @@ def save_checkpoint(
     pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
     policy.save_pretrained(pretrained_dir)
     cfg.save_pretrained(pretrained_dir)
+    if cfg.peft is not None:
+        # When using PEFT, `policy.save_pretrained` will only write the adapter weights + config, not the
+        # policy config, which we need for loading the model. In this case we write it ourselves.
+        policy.config.save_pretrained(pretrained_dir)
     if preprocessor is not None:
         preprocessor.save_pretrained(pretrained_dir)
     if postprocessor is not None:
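A checkpoint saved this way therefore contains the adapter files plus the explicitly written policy config, rather than the full model weights. A sketch of what the tests below verify; the checkpoint path is a placeholder:

```python
from pathlib import Path

pretrained_dir = Path("outputs/train/my_run/checkpoints/last/pretrained_model")  # placeholder

for name in ["adapter_config.json", "adapter_model.safetensors", "config.json"]:
    assert (pretrained_dir / name).exists()
assert not (pretrained_dir / "model.safetensors").exists()  # full weights are not stored
```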
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 000000000..ddaae0c9b
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,229 @@
+import importlib
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+from safetensors.torch import load_file
+
+from .utils import require_package
+
+
+def run_command(cmd, module, args):
+    module = importlib.import_module(f"lerobot.scripts.{module}")
+    with patch("sys.argv", [cmd] + args):
+        module.main()
+
+
+def lerobot_train(args):
+    return run_command(cmd="lerobot-train", module="lerobot_train", args=args)
+
+
+def lerobot_record(args):
+    return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
+
+
+def resolve_model_id_for_peft_training(policy_type):
+    """PEFT training needs pretrained models; this finds the pretrained model of a policy type for PEFT training."""
+    if policy_type == "smolvla":
+        return "lerobot/smolvla_base"
+
+    raise ValueError(f"No pretrained model known for {policy_type}. PEFT training will not work.")
+
+
+@pytest.mark.parametrize("policy_type", ["smolvla"])
+@require_package("peft")
+def test_peft_training_push_to_hub_works(policy_type, tmp_path):
+    """Ensure that push to hub stores only the PEFT adapter, not the full model weights."""
+    output_dir = tmp_path / f"output_{policy_type}"
+    upload_folder_contents = set()
+
+    model_id = resolve_model_id_for_peft_training(policy_type)
+
+    def mock_upload_folder(*args, **kwargs):
+        folder_path = kwargs["folder_path"]
+        # we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folder()
+        upload_folder_contents.update(os.listdir(folder_path))
+        return MagicMock()
+
+    with (
+        patch("huggingface_hub.HfApi.create_repo"),
+        patch("huggingface_hub.HfApi.upload_folder", mock_upload_folder),
+    ):
+        lerobot_train(
+            [
+                f"--policy.path={model_id}",
+                "--policy.push_to_hub=true",
+                "--policy.repo_id=foo/bar",
+                "--policy.input_features=null",
+                "--policy.output_features=null",
+                "--peft.method_type=LORA",
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0, 1]",
+                "--steps=1",
+                f"--output_dir={output_dir}",
+            ]
+        )
+
+    assert "adapter_model.safetensors" in upload_folder_contents
+    assert "config.json" in upload_folder_contents
+    assert "adapter_config.json" in upload_folder_contents
+
+
+@pytest.mark.parametrize("policy_type", ["smolvla"])
+@require_package("peft")
+def test_peft_training_works(policy_type, tmp_path):
+    """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
+    output_dir = tmp_path / f"output_{policy_type}"
+    model_id = resolve_model_id_for_peft_training(policy_type)
+
+    lerobot_train(
+        [
+            f"--policy.path={model_id}",
+            "--policy.push_to_hub=false",
+            "--policy.input_features=null",
+            "--policy.output_features=null",
+            "--peft.method_type=LORA",
+            "--dataset.repo_id=lerobot/pusht",
+            "--dataset.episodes=[0, 1]",
+            "--steps=1",
+            f"--output_dir={output_dir}",
+        ]
+    )
+
+    policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
+
+    for file in ["adapter_config.json", "adapter_model.safetensors", "config.json"]:
+        assert (policy_dir / file).exists()
+
+    # This is the default case where we fine-tune a pre-trained policy on new data.
+    # We assume that we target policy-specific modules, including the action and state projections,
+    # so these must be part of the trained state dict.
+    state_dict = load_file(policy_dir / "adapter_model.safetensors")
+
+    adapted_keys = [
+        "state_proj",
+        "action_in_proj",
+        "action_out_proj",
+        "action_time_mlp_in",
+        "action_time_mlp_out",
+    ]
+
+    found_keys = [
+        module_key
+        for module_key in adapted_keys
+        for state_dict_key in state_dict
+        if f".{module_key}." in state_dict_key
+    ]
+
+    assert set(found_keys) == set(adapted_keys)
+
+
+@pytest.mark.parametrize("policy_type", ["smolvla"])
+@require_package("peft")
+def test_peft_training_params_are_fewer(policy_type, tmp_path):
+    """Check that with PEFT only a subset of the policy's parameters is trainable."""
+    output_dir = tmp_path / f"output_{policy_type}"
+    model_id = resolve_model_id_for_peft_training(policy_type)
+
+    def dummy_update_policy(
+        train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs
+    ):
+        params_total = sum(p.numel() for p in policy.parameters())
+        params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+
+        assert params_total > params_trainable
+
+        return train_metrics, {}
+
+    with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
+        lerobot_train(
+            [
+                f"--policy.path={model_id}",
+                "--policy.push_to_hub=false",
+                "--policy.input_features=null",
+                "--policy.output_features=null",
+                "--peft.method_type=LORA",
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0, 1]",
+                "--steps=1",
+                f"--output_dir={output_dir}",
+            ]
+        )
+
+
+class DummyRobot:
+    name = "dummy"
+    cameras = []
+    action_features = {"foo": 1.0, "bar": 2.0}
+    observation_features = {"obs1": 1.0, "obs2": 2.0}
+    is_connected = True
+
+    def connect(self, *args):
+        pass
+
+    def disconnect(self):
+        pass
+
+
+def dummy_make_robot_from_config(*args, **kwargs):
+    return DummyRobot()
+
+
+@pytest.mark.parametrize("policy_type", ["smolvla"])
+@require_package("peft")
+def test_peft_record_loads_policy(policy_type, tmp_path):
+    """Train a policy with PEFT and attempt to load it with `lerobot-record`."""
+    from peft import PeftModel
+
+    output_dir = tmp_path / f"output_{policy_type}"
+    model_id = resolve_model_id_for_peft_training(policy_type)
+
+    lerobot_train(
+        [
+            f"--policy.path={model_id}",
+            "--policy.push_to_hub=false",
+            "--policy.input_features=null",
+            "--policy.output_features=null",
+            "--peft.method_type=LORA",
+            "--dataset.repo_id=lerobot/pusht",
+            "--dataset.episodes=[0, 1]",
+            "--steps=1",
+            f"--output_dir={output_dir}",
+        ]
+    )
+
+    policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
+    dataset_dir = tmp_path / "eval_pusht"
+    single_task = "move the table"
+    loaded_policy = None
+
+    def dummy_record_loop(*args, **kwargs):
+        nonlocal loaded_policy
+
+        if "dataset" not in kwargs:
+            return
+
+        dataset = kwargs["dataset"]
+        dataset.add_frame({"task": single_task})
+        loaded_policy = kwargs["policy"]
+
+    with (
+        patch("lerobot.robots.make_robot_from_config", dummy_make_robot_from_config),
+        # disable the record loop since we're only interested in successful loading of the policy
+        patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
+        # disable speech output
+        patch("lerobot.utils.utils.say"),
+    ):
+        lerobot_record(
+            [
+                f"--policy.path={policy_dir}",
+                "--robot.type=so101_follower",
+                "--robot.port=/dev/null",
+                "--dataset.repo_id=lerobot/eval_pusht",
+                f'--dataset.single_task="{single_task}"',
+                f"--dataset.root={dataset_dir}",
+                "--dataset.push_to_hub=false",
+            ]
+        )
+
+    assert isinstance(loaded_policy, PeftModel)
diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py
index 892503e97..4791caf58 100644
--- a/tests/utils/test_train_utils.py
+++ b/tests/utils/test_train_utils.py
@@ -82,6 +82,20 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
     mock_save_training_state.assert_called_once()
 
 
+@patch("lerobot.utils.train_utils.save_training_state")
+def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
+    policy = Mock()
+    policy.config = Mock()
+    policy.config.save_pretrained = Mock()
+    cfg = Mock()
+    cfg.peft = Mock()  # `save_checkpoint` takes the PEFT path when `cfg.peft` is not None
+    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
+    policy.save_pretrained.assert_called_once()
+    cfg.save_pretrained.assert_called_once()
+    policy.config.save_pretrained.assert_called_once()
+    mock_save_training_state.assert_called_once()
+
+
 def test_save_training_state(tmp_path, optimizer, scheduler):
     save_training_state(tmp_path, 10, optimizer, scheduler)
     assert (tmp_path / TRAINING_STATE_DIR).is_dir()