mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6d8ef7dc60 | |||
| ca6d764107 |
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -26,6 +25,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
from lerobot.utils.import_utils import _deepdiff_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _deepdiff_available:
|
||||
@@ -76,7 +76,7 @@ def predict_action(
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
get_safe_autocast_context(device, enabled=use_amp),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||
|
||||
@@ -43,6 +43,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
@@ -243,7 +244,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.forward(groot_inputs)
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
@@ -275,7 +276,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
@@ -31,7 +31,6 @@ import logging
|
||||
import os
|
||||
import types
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
@@ -43,6 +42,7 @@ from torch.distributions import Beta
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
|
||||
|
||||
from ..rtc.modeling_rtc import RTCProcessor
|
||||
@@ -1644,10 +1644,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
device=device,
|
||||
)
|
||||
action_dim = self._output_action_dim(batch)
|
||||
autocast_context = (
|
||||
torch.autocast(device_type=device.type, dtype=model_dtype)
|
||||
if device.type in {"cuda", "cpu"} and model_dtype in {torch.bfloat16, torch.float16}
|
||||
else nullcontext()
|
||||
autocast_context = get_safe_autocast_context(
|
||||
device, dtype=model_dtype, enabled=model_dtype in {torch.bfloat16, torch.float16}
|
||||
)
|
||||
with autocast_context:
|
||||
if inference_action_mode == "discrete":
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import Tensor, nn
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
@@ -183,7 +184,7 @@ class VLAJEPAModel(nn.Module):
|
||||
action_idx = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
with get_safe_autocast_context(device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
|
||||
@@ -250,7 +251,7 @@ class VLAJEPAModel(nn.Module):
|
||||
) -> Tensor:
|
||||
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
|
||||
device_type = next(self.parameters()).device.type
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
with get_safe_autocast_context(device_type, dtype=torch.float32):
|
||||
r = self.config.repeated_diffusion_steps
|
||||
horizon = self.config.chunk_size
|
||||
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
@@ -25,6 +24,7 @@ import torch
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
@@ -102,11 +102,7 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
|
||||
# tensor/array values are not shared with any other reader.
|
||||
observation = copy(obs_frame)
|
||||
autocast_ctx = (
|
||||
torch.autocast(device_type=self._device.type)
|
||||
if self._device.type == "cuda" and self._policy.config.use_amp
|
||||
else nullcontext()
|
||||
)
|
||||
autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp)
|
||||
with torch.inference_mode(), autocast_ctx:
|
||||
observation = prepare_observation_for_inference(
|
||||
observation, self._device, self._task, self._robot_type
|
||||
|
||||
@@ -56,7 +56,6 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from functools import partial
|
||||
@@ -86,7 +85,7 @@ from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_proces
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context, get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -698,7 +697,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
with torch.no_grad(), get_safe_autocast_context(device, enabled=cfg.policy.use_amp):
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
policy=policy,
|
||||
|
||||
@@ -33,7 +33,12 @@ from .constants import (
|
||||
REWARD,
|
||||
)
|
||||
from .decorators import check_if_already_connected, check_if_not_connected
|
||||
from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
|
||||
from .device_utils import (
|
||||
auto_select_torch_device,
|
||||
get_safe_autocast_context,
|
||||
get_safe_torch_device,
|
||||
is_torch_device_available,
|
||||
)
|
||||
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from .import_utils import is_package_available, require_package
|
||||
|
||||
@@ -51,6 +56,7 @@ __all__ = [
|
||||
"REWARD",
|
||||
# Device utilities
|
||||
"auto_select_torch_device",
|
||||
"get_safe_autocast_context",
|
||||
"get_safe_torch_device",
|
||||
"is_torch_device_available",
|
||||
# Import guards
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
@@ -107,3 +108,25 @@ def is_amp_available(device: str):
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def get_safe_autocast_context(
    device: str | torch.device,
    *,
    dtype: torch.dtype | None = None,
    enabled: bool = True,
) -> AbstractContextManager:
    """Build an autocast context for ``device``, degrading to a no-op when AMP cannot be used.

    Callers may request autocast unconditionally: a real ``torch.autocast`` is only
    returned when ``enabled`` is true *and* ``is_amp_available`` reports support for the
    device type. In every other case (autocast disabled, AMP reported as unsupported, or
    a device type that ``is_amp_available`` rejects with ``ValueError``) a
    ``nullcontext()`` is returned instead.

    Args:
        device: Target device, either a ``torch.device`` or a string such as ``"cuda"``
            or ``"cuda:0"``. For strings, only the part before the first ``":"`` is used.
        dtype: Forwarded unchanged to ``torch.autocast``.
        enabled: When false, a ``nullcontext()`` is returned regardless of the device.

    Returns:
        A ``torch.autocast`` instance, or ``nullcontext()`` as a no-op fallback.
    """
    if isinstance(device, torch.device):
        backend = device.type
    else:
        # Drop an optional ":<index>" suffix so only the device type remains.
        backend, _, _ = str(device).partition(":")

    try:
        supported = is_amp_available(backend)
    except ValueError:
        # Device types unknown to is_amp_available are treated as "no AMP" rather than an error.
        supported = False

    if enabled and supported:
        return torch.autocast(device_type=backend, dtype=dtype)
    return nullcontext()
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.device_utils import get_safe_autocast_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("device", "enabled", "expect_autocast"),
    [
        ("cpu", True, True),  # AMP-capable device -> real autocast
        (torch.device("cpu"), True, True),  # torch.device instances are accepted too
        ("cpu", False, False),  # autocast explicitly turned off -> no-op
        ("mps", True, False),  # mps is reported as not AMP-capable -> no-op
        ("privateuseone", True, False),  # unrecognised device type -> safe no-op
    ],
)
def test_get_safe_autocast_context(device, enabled, expect_autocast):
    """Check which context manager type is returned for each device/enabled combination."""
    context = get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=enabled)

    if not expect_autocast:
        # Fallback path: the helper must hand back a plain no-op context.
        assert isinstance(context, nullcontext)
        return

    assert isinstance(context, torch.autocast)
    # Entering the returned context must actually switch CPU autocast on.
    with context:
        assert torch.is_autocast_enabled("cpu")
|
||||
Reference in New Issue
Block a user