Compare commits

..

1 Commits

Author SHA1 Message Date
github-actions[bot] 62067f8eb9 chore(dependencies): update uv.lock 2026-07-03 08:28:33 +00:00
10 changed files with 880 additions and 1020 deletions
+2 -2
View File
@@ -18,6 +18,7 @@ from __future__ import annotations
# Utilities
########################################################################################
import time
from contextlib import nullcontext
from copy import copy
from typing import TYPE_CHECKING, Any
@@ -25,7 +26,6 @@ import numpy as np
import torch
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _deepdiff_available, require_package
if TYPE_CHECKING or _deepdiff_available:
@@ -76,7 +76,7 @@ def predict_action(
observation = copy(observation)
with (
torch.inference_mode(),
get_safe_autocast_context(device, enabled=use_amp),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
observation = prepare_observation_for_inference(observation, device, task, robot_type)
+2 -3
View File
@@ -43,7 +43,6 @@ from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
@@ -244,7 +243,7 @@ class GrootPolicy(PreTrainedPolicy):
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.forward(groot_inputs)
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
@@ -276,7 +275,7 @@ class GrootPolicy(PreTrainedPolicy):
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
@@ -31,6 +31,7 @@ import logging
import os
import types
from collections import deque
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -42,7 +43,6 @@ from torch.distributions import Beta
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from ..rtc.modeling_rtc import RTCProcessor
@@ -1644,8 +1644,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
device=device,
)
action_dim = self._output_action_dim(batch)
autocast_context = get_safe_autocast_context(
device, dtype=model_dtype, enabled=model_dtype in {torch.bfloat16, torch.float16}
autocast_context = (
torch.autocast(device_type=device.type, dtype=model_dtype)
if device.type in {"cuda", "cpu"} and model_dtype in {torch.bfloat16, torch.float16}
else nullcontext()
)
with autocast_context:
if inference_action_mode == "discrete":
@@ -26,7 +26,6 @@ from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
@@ -184,7 +183,7 @@ class VLAJEPAModel(nn.Module):
action_idx = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with get_safe_autocast_context(device_type, dtype=torch.bfloat16):
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
@@ -251,7 +250,7 @@ class VLAJEPAModel(nn.Module):
) -> Tensor:
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
device_type = next(self.parameters()).device.type
with get_safe_autocast_context(device_type, dtype=torch.float32):
with torch.autocast(device_type=device_type, dtype=torch.float32):
r = self.config.repeated_diffusion_steps
horizon = self.config.chunk_size
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
+6 -2
View File
@@ -17,6 +17,7 @@
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
@@ -24,7 +25,6 @@ import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.device_utils import get_safe_autocast_context
from .base import InferenceEngine
@@ -102,7 +102,11 @@ class SyncInferenceEngine(InferenceEngine):
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp)
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, self._device, self._task, self._robot_type
+3 -2
View File
@@ -56,6 +56,7 @@ import threading
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial
@@ -85,7 +86,7 @@ from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_proces
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_autocast_context, get_safe_torch_device
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
@@ -697,7 +698,7 @@ def eval_main(cfg: EvalPipelineConfig):
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), get_safe_autocast_context(device, enabled=cfg.policy.use_amp):
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
policy=policy,
+1 -7
View File
@@ -33,12 +33,7 @@ from .constants import (
REWARD,
)
from .decorators import check_if_already_connected, check_if_not_connected
from .device_utils import (
auto_select_torch_device,
get_safe_autocast_context,
get_safe_torch_device,
is_torch_device_available,
)
from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from .import_utils import is_package_available, require_package
@@ -56,7 +51,6 @@ __all__ = [
"REWARD",
# Device utilities
"auto_select_torch_device",
"get_safe_autocast_context",
"get_safe_torch_device",
"is_torch_device_available",
# Import guards
-23
View File
@@ -15,7 +15,6 @@
# limitations under the License.
import logging
from contextlib import AbstractContextManager, nullcontext
import torch
@@ -108,25 +107,3 @@ def is_amp_available(device: str):
return False
else:
raise ValueError(f"Unknown device '{device}.")
def get_safe_autocast_context(
device: str | torch.device,
*,
dtype: torch.dtype | None = None,
enabled: bool = True,
) -> AbstractContextManager:
"""Return a ``torch.autocast`` context, or a no-op when AMP is unsupported on ``device``.
Autocast is only entered on devices that support AMP (cuda, xpu, cpu); on mps and any
unknown device this falls back to ``nullcontext()`` so callers can request autocast
unconditionally without breaking on unsupported backends.
"""
device_type = device.type if isinstance(device, torch.device) else str(device).split(":", 1)[0]
try:
amp_ok = is_amp_available(device_type)
except ValueError:
amp_ok = False
if not enabled or not amp_ok:
return nullcontext()
return torch.autocast(device_type=device_type, dtype=dtype)
-40
View File
@@ -1,40 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import pytest
import torch
from lerobot.utils.device_utils import get_safe_autocast_context
@pytest.mark.parametrize(
("device", "enabled", "expect_autocast"),
[
("cpu", True, True), # AMP-capable device -> real autocast
(torch.device("cpu"), True, True), # accepts torch.device
("cpu", False, False), # explicitly disabled -> no-op
("mps", True, False), # AMP unsupported on mps -> no-op
("privateuseone", True, False), # unknown device -> safe no-op
],
)
def test_get_safe_autocast_context(device, enabled, expect_autocast):
ctx = get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=enabled)
if expect_autocast:
assert isinstance(ctx, torch.autocast)
with ctx:
assert torch.is_autocast_enabled("cpu")
else:
assert isinstance(ctx, nullcontext)
Generated
+859 -935
View File
File diff suppressed because it is too large Load Diff