diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 1d8f21afc..7023cb1d0 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -138,22 +138,26 @@ def get_default_peft_configuration(policy_type): """ if policy_type == "smolvla": return { - "target_modules": r"(model\.vlm_with_expert\.lm_expert\..*\.(q_proj|v_proj)|model\.action_.*|model\.state_proj.*)", + "target_modules": r"(model\.vlm_with_expert\.lm_expert\..*\.(q_proj|v_proj))", "modules_to_save": [ - # These are inf on load otherwise - "normalize_inputs", - "normalize_targets", - "unnormalize_outputs", + # these are initialized randomly and need full-finetuning + "state_proj", + "action_in_proj", + "action_out_proj", + "action_time_mlp_in", + "action_time_mlp_out", ], } - elif policy_type == "act": + elif policy_type in ("pi0", "pi05"): return { - "target_modules": r"(.*_proj|.*\.action_head)", + "target_modules": r".*\.gemma_expert\..*\.self_attn\.(q_proj|v_proj)", "modules_to_save": [ - # These are inf on load otherwise - "normalize_inputs", - "normalize_targets", - "unnormalize_outputs", + # these are initialized randomly and need full-finetuning + "state_proj", + "action_in_proj", + "action_out_proj", + "action_time_mlp_in", + "action_time_mlp_out", ], } @@ -170,7 +174,7 @@ def wrap_policy_in_peft_model(cfg, policy): peft_config_policy = get_default_peft_configuration(cfg.policy.type) peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {} - peft_config_cli['modules_to_save'] = peft_config_cli['full_training_modules'] # compatibility with PEFT + peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT peft_method_type = PeftType[peft_config_cli["method_type"].upper()] peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]