mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 00:27:15 +00:00
fix(autocast): route inference autocasts through safe helper
Apply get_safe_autocast_context to the control_utils and sync inference paths for uniformity with lerobot_eval. AMP is now enabled on any AMP-capable device (cuda, xpu, cpu) when use_amp is set, and stays a no-op on mps.
This commit is contained in:
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -26,6 +25,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
from lerobot.utils.import_utils import _deepdiff_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _deepdiff_available:
|
||||
@@ -76,7 +76,7 @@ def predict_action(
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
get_safe_autocast_context(device, enabled=use_amp),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
@@ -25,6 +24,7 @@ import torch
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
@@ -102,11 +102,7 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
|
||||
# tensor/array values are not shared with any other reader.
|
||||
observation = copy(obs_frame)
|
||||
autocast_ctx = (
|
||||
torch.autocast(device_type=self._device.type)
|
||||
if self._device.type == "cuda" and self._policy.config.use_amp
|
||||
else nullcontext()
|
||||
)
|
||||
autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp)
|
||||
with torch.inference_mode(), autocast_ctx:
|
||||
observation = prepare_observation_for_inference(
|
||||
observation, self._device, self._task, self._robot_type
|
||||
|
||||
Reference in New Issue
Block a user