diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index c26c7b4af..ec9138bf7 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -760,6 +760,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): self._queues["task"] = deque(maxlen=self.config.n_obs_steps) + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations""" self.eval() @@ -778,6 +779,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): actions = self._generate_actions(batch) return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations""" if ACTION in batch: