add torch.no_grad decorators

This commit is contained in:
Bryson Jones
2025-12-10 13:48:28 -08:00
parent 3b2a4f548c
commit 3a16a002f8
@@ -760,6 +760,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
@@ -778,6 +779,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
actions = self._generate_actions(batch)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch: