mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-15 07:19:51 +00:00
groot: reuse lerobot get_device_from_parameters instead of inline lookup
modeling_groot.py duplicated next(self.parameters()).device twice. LeRobot ships get_device_from_parameters in policies/utils.py (used by diffusion, vqbet, tdmpc, gaussian_actor). Reuse it for consistency with the framework.
This commit is contained in:
@@ -40,6 +40,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
@@ -399,7 +400,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
@@ -437,7 +438,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
|
||||
Reference in New Issue
Block a user