From 3a16a002f869b8e2f289931633818652a4273ea3 Mon Sep 17 00:00:00 2001 From: Bryson Jones Date: Wed, 10 Dec 2025 13:48:28 -0800 Subject: [PATCH] add torch.no_grad decorators --- src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index c26c7b4af..ec9138bf7 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -760,6 +760,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): self._queues["task"] = deque(maxlen=self.config.n_obs_steps) + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations""" self.eval() @@ -778,6 +779,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): actions = self._generate_actions(batch) return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations""" if ACTION in batch: