Compare commits

...

2 Commits

Author SHA1 Message Date
CarolinePascal 6d8ef7dc60 fix(autocast): route inference autocasts through safe helper
Apply get_safe_autocast_context to the control_utils and sync inference
paths for uniformity with lerobot_eval. AMP is now enabled on any
AMP-capable device (cuda, xpu, cpu) when use_amp is set, and stays a
no-op on mps.
2026-07-03 13:22:30 +02:00
CarolinePascal ca6d764107 fix(autocast): gate autocast on AMP-capable devices
Add get_safe_autocast_context helper that only enters torch.autocast on
devices supporting AMP (cuda, xpu, cpu) and falls back to a no-op on mps
and unknown backends. Route the previously unconditional/underspecified
autocasts (vla_jepa, groot, molmoact2, lerobot_eval) through it so
autocast can be requested unconditionally without breaking on unsupported
devices.
2026-07-03 11:22:33 +02:00
9 changed files with 85 additions and 21 deletions
+2 -2
View File
@@ -18,7 +18,6 @@ from __future__ import annotations
# Utilities
########################################################################################
import time
from contextlib import nullcontext
from copy import copy
from typing import TYPE_CHECKING, Any
@@ -26,6 +25,7 @@ import numpy as np
import torch
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _deepdiff_available, require_package
if TYPE_CHECKING or _deepdiff_available:
@@ -76,7 +76,7 @@ def predict_action(
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
get_safe_autocast_context(device, enabled=use_amp),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
observation = prepare_observation_for_inference(observation, device, task, robot_type)
+3 -2
View File
@@ -43,6 +43,7 @@ from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
@@ -243,7 +244,7 @@ class GrootPolicy(PreTrainedPolicy):
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.forward(groot_inputs)
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
@@ -275,7 +276,7 @@ class GrootPolicy(PreTrainedPolicy):
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
@@ -31,7 +31,6 @@ import logging
import os
import types
from collections import deque
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -43,6 +42,7 @@ from torch.distributions import Beta
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from ..rtc.modeling_rtc import RTCProcessor
@@ -1644,10 +1644,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
device=device,
)
action_dim = self._output_action_dim(batch)
autocast_context = (
torch.autocast(device_type=device.type, dtype=model_dtype)
if device.type in {"cuda", "cpu"} and model_dtype in {torch.bfloat16, torch.float16}
else nullcontext()
autocast_context = get_safe_autocast_context(
device, dtype=model_dtype, enabled=model_dtype in {torch.bfloat16, torch.float16}
)
with autocast_context:
if inference_action_mode == "discrete":
@@ -26,6 +26,7 @@ from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
@@ -183,7 +184,7 @@ class VLAJEPAModel(nn.Module):
action_idx = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
with get_safe_autocast_context(device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
@@ -250,7 +251,7 @@ class VLAJEPAModel(nn.Module):
) -> Tensor:
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.float32):
with get_safe_autocast_context(device_type, dtype=torch.float32):
r = self.config.repeated_diffusion_steps
horizon = self.config.chunk_size
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
+2 -6
View File
@@ -17,7 +17,6 @@
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
@@ -25,6 +24,7 @@ import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.device_utils import get_safe_autocast_context
from .base import InferenceEngine
@@ -102,11 +102,7 @@ class SyncInferenceEngine(InferenceEngine):
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, self._device, self._task, self._robot_type
+2 -3
View File
@@ -56,7 +56,6 @@ import threading
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial
@@ -86,7 +85,7 @@ from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_proces
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.device_utils import get_safe_autocast_context, get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
@@ -698,7 +697,7 @@ def eval_main(cfg: EvalPipelineConfig):
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
with torch.no_grad(), get_safe_autocast_context(device, enabled=cfg.policy.use_amp):
info = eval_policy_all(
envs=envs,
policy=policy,
+7 -1
View File
@@ -33,7 +33,12 @@ from .constants import (
REWARD,
)
from .decorators import check_if_already_connected, check_if_not_connected
from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
from .device_utils import (
auto_select_torch_device,
get_safe_autocast_context,
get_safe_torch_device,
is_torch_device_available,
)
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from .import_utils import is_package_available, require_package
@@ -51,6 +56,7 @@ __all__ = [
"REWARD",
# Device utilities
"auto_select_torch_device",
"get_safe_autocast_context",
"get_safe_torch_device",
"is_torch_device_available",
# Import guards
+23
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from contextlib import AbstractContextManager, nullcontext
import torch
@@ -107,3 +108,25 @@ def is_amp_available(device: str):
return False
else:
raise ValueError(f"Unknown device '{device}.")
def get_safe_autocast_context(
    device: str | torch.device,
    *,
    dtype: torch.dtype | None = None,
    enabled: bool = True,
) -> AbstractContextManager:
    """Build an autocast context manager that degrades to a no-op when AMP cannot be used.

    A real ``torch.autocast`` is returned only when ``enabled`` is truthy and
    ``is_amp_available`` reports AMP support for the device type. In every other
    case, including device types that ``is_amp_available`` rejects with a
    ``ValueError``, a ``nullcontext()`` is returned, so callers may wrap code in
    this context without checking the backend themselves.

    Args:
        device: A ``torch.device`` or a device string; for strings, any index
            suffix after the first ``:`` (e.g. ``"cuda:0"``) is ignored.
        dtype: Forwarded unchanged to ``torch.autocast``.
        enabled: When falsy, the no-op context is always returned.

    Returns:
        Either a ``torch.autocast`` instance or a ``nullcontext`` instance.
    """
    if isinstance(device, torch.device):
        backend = device.type
    else:
        # Keep only the backend name, dropping an optional ":<index>" suffix.
        backend = str(device).partition(":")[0]

    try:
        supported = is_amp_available(backend)
    except ValueError:
        # The availability check raises on device types it does not know;
        # treat those as "no AMP" rather than propagating the error.
        supported = False

    if enabled and supported:
        return torch.autocast(device_type=backend, dtype=dtype)
    return nullcontext()
+40
View File
@@ -0,0 +1,40 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import pytest
import torch
from lerobot.utils.device_utils import get_safe_autocast_context
@pytest.mark.parametrize(
    ("device", "enabled", "expect_autocast"),
    [
        # Device with AMP support, requested as a string: real autocast expected.
        ("cpu", True, True),
        # Same device passed as a torch.device object.
        (torch.device("cpu"), True, True),
        # Caller disabled autocast: no-op expected even on a capable device.
        ("cpu", False, False),
        # mps is expected to yield the no-op context.
        ("mps", True, False),
        # Device type the helper does not recognise: no-op instead of an error.
        ("privateuseone", True, False),
    ],
)
def test_get_safe_autocast_context(device, enabled, expect_autocast):
    """Check which kind of context manager the helper returns for each case."""
    context = get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=enabled)

    if not expect_autocast:
        assert isinstance(context, nullcontext)
        return

    assert isinstance(context, torch.autocast)
    # Every autocast-expecting case above targets the CPU backend, so entering
    # the context must switch CPU autocast on.
    with context:
        assert torch.is_autocast_enabled("cpu")