From 6d8ef7dc6039f1ec44baab1086bbce6f6e142b9e Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 3 Jul 2026 13:22:30 +0200 Subject: [PATCH] fix(autocast): route inference autocasts through safe helper Apply get_safe_autocast_context to the control_utils and sync inference paths for uniformity with lerobot_eval. AMP is now enabled on any AMP-capable device (cuda, xpu, cpu) when use_amp is set, and stays a no-op on mps. --- src/lerobot/common/control_utils.py | 4 ++-- src/lerobot/rollout/inference/sync.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/lerobot/common/control_utils.py b/src/lerobot/common/control_utils.py index e3130643d..d321338bb 100644 --- a/src/lerobot/common/control_utils.py +++ b/src/lerobot/common/control_utils.py @@ -18,7 +18,6 @@ from __future__ import annotations # Utilities ######################################################################################## import time -from contextlib import nullcontext from copy import copy from typing import TYPE_CHECKING, Any @@ -26,6 +25,7 @@ import numpy as np import torch from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference +from lerobot.utils.device_utils import get_safe_autocast_context from lerobot.utils.import_utils import _deepdiff_available, require_package if TYPE_CHECKING or _deepdiff_available: @@ -76,7 +76,7 @@ def predict_action( observation = copy(observation) with ( torch.inference_mode(), - torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + get_safe_autocast_context(device, enabled=use_amp), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension observation = prepare_observation_for_inference(observation, device, task, robot_type) diff --git a/src/lerobot/rollout/inference/sync.py b/src/lerobot/rollout/inference/sync.py index 2bb05b6ab..d890f9783 100644 --- a/src/lerobot/rollout/inference/sync.py +++ b/src/lerobot/rollout/inference/sync.py @@ -17,7 +17,6 @@ from __future__ import annotations import logging -from contextlib import nullcontext from copy import copy import torch @@ -25,6 +24,7 @@ import torch from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference from lerobot.processor import PolicyProcessorPipeline +from lerobot.utils.device_utils import get_safe_autocast_context from .base import InferenceEngine @@ -102,11 +102,7 @@ class SyncInferenceEngine(InferenceEngine): # ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the # tensor/array values are not shared with any other reader. observation = copy(obs_frame) - autocast_ctx = ( - torch.autocast(device_type=self._device.type) - if self._device.type == "cuda" and self._policy.config.use_amp - else nullcontext() - ) + autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp) with torch.inference_mode(), autocast_ctx: observation = prepare_observation_for_inference( observation, self._device, self._task, self._robot_type