diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 0d252b296..4fee851e0 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -64,7 +64,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): config_class = MultiTaskDiTConfig name = "multi_task_dit" - def __init__(self, config: MultiTaskDiTConfig): + def __init__(self, config: MultiTaskDiTConfig, **kwargs): super().__init__(config) config.validate_features() self.config = config