mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 06:07:40 +00:00
add torch.no_grad decorators
This commit is contained in:
@@ -760,6 +760,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
|
||||
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations"""
|
||||
self.eval()
|
||||
@@ -778,6 +779,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
actions = self._generate_actions(batch)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations"""
|
||||
if ACTION in batch:
|
||||
|
||||
Reference in New Issue
Block a user