fix(autocast): route inference autocasts through safe helper

Apply get_safe_autocast_context to the control_utils and sync inference
paths for uniformity with lerobot_eval. AMP is now enabled on any
AMP-capable device (cuda, xpu, cpu) when use_amp is set, and stays a
no-op on mps.
This commit is contained in:
CarolinePascal
2026-07-03 13:22:30 +02:00
parent ca6d764107
commit 6d8ef7dc60
2 changed files with 4 additions and 8 deletions
+2 -2
View File
@@ -18,7 +18,6 @@ from __future__ import annotations
# Utilities
########################################################################################
import time
from contextlib import nullcontext
from copy import copy
from typing import TYPE_CHECKING, Any
@@ -26,6 +25,7 @@ import numpy as np
import torch
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _deepdiff_available, require_package
if TYPE_CHECKING or _deepdiff_available:
@@ -76,7 +76,7 @@ def predict_action(
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
get_safe_autocast_context(device, enabled=use_amp),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
observation = prepare_observation_for_inference(observation, device, task, robot_type)
+2 -6
View File
@@ -17,7 +17,6 @@
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
@@ -25,6 +24,7 @@ import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.device_utils import get_safe_autocast_context
from .base import InferenceEngine
@@ -102,11 +102,7 @@ class SyncInferenceEngine(InferenceEngine):
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, self._device, self._task, self._robot_type