groot: reuse lerobot get_device_from_parameters instead of inline lookup

modeling_groot.py duplicated next(self.parameters()).device twice. LeRobot
ships get_device_from_parameters in policies/utils.py (used by diffusion,
vqbet, tdmpc, gaussian_actor). Reuse it for consistency with the framework.
This commit is contained in:
nv-sachdevkartik
2026-06-11 18:03:28 +00:00
parent 162b07512a
commit bba996ef8d
+3 -2
View File
@@ -40,6 +40,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_groot import (
GROOT_N1_7,
GrootConfig,
@@ -399,7 +400,7 @@ class GrootPolicy(PreTrainedPolicy):
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
# Get device from model parameters
device = next(self.parameters()).device
device = get_device_from_parameters(self)
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
@@ -437,7 +438,7 @@ class GrootPolicy(PreTrainedPolicy):
)
# Get device from model parameters
device = next(self.parameters()).device
device = get_device_from_parameters(self)
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):