diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 1ec290fe3..8f172d0ea 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -117,6 +117,16 @@ def get_default_peft_configuration(policy_type): "unnormalize_outputs", ], } + elif policy_type == "act": + return { + "target_modules": r"(.*_proj|.*\.action_head)", + "modules_to_save": [ + # These are inf on load otherwise + "normalize_inputs", + "normalize_targets", + "unnormalize_outputs", + ], + } return {'modules_to_save': None}