From bba996ef8d321f5ccce874a03d48cd593392abee Mon Sep 17 00:00:00 2001 From: nv-sachdevkartik Date: Thu, 11 Jun 2026 18:03:28 +0000 Subject: [PATCH] groot: reuse lerobot get_device_from_parameters instead of inline lookup modeling_groot.py duplicated next(self.parameters()).device twice. LeRobot ships get_device_from_parameters in policies/utils.py (used by diffusion, vqbet, tdmpc, gaussian_actor). Reuse it for consistency with the framework. --- src/lerobot/policies/groot/modeling_groot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 0e92115d4..d0587a425 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -40,6 +40,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES from lerobot.utils.import_utils import require_package from ..pretrained import PreTrainedPolicy +from ..utils import get_device_from_parameters from .configuration_groot import ( GROOT_N1_7, GrootConfig, @@ -399,7 +400,7 @@ class GrootPolicy(PreTrainedPolicy): groot_inputs = self._filter_groot_inputs(batch, include_action=True) # Get device from model parameters - device = next(self.parameters()).device + device = get_device_from_parameters(self) # Run GR00T forward under bf16 autocast when enabled to reduce activation memory # Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts. @@ -437,7 +438,7 @@ class GrootPolicy(PreTrainedPolicy): ) # Get device from model parameters - device = next(self.parameters()).device + device = get_device_from_parameters(self) # Use bf16 autocast for inference to keep memory low and match backbone dtype with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):