diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 0e92115d4..d0587a425 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -40,6 +40,7 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES from lerobot.utils.import_utils import require_package from ..pretrained import PreTrainedPolicy +from ..utils import get_device_from_parameters from .configuration_groot import ( GROOT_N1_7, GrootConfig, @@ -399,7 +400,7 @@ class GrootPolicy(PreTrainedPolicy): groot_inputs = self._filter_groot_inputs(batch, include_action=True) # Get device from model parameters - device = next(self.parameters()).device + device = get_device_from_parameters(self) # Run GR00T forward under bf16 autocast when enabled to reduce activation memory # Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts. @@ -437,7 +438,7 @@ class GrootPolicy(PreTrainedPolicy): ) # Get device from model parameters - device = next(self.parameters()).device + device = get_device_from_parameters(self) # Use bf16 autocast for inference to keep memory low and match backbone dtype with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):