diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 6f1ed5173..353549c52 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -22,9 +22,11 @@ from typing import Literal import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoTokenizer, GemmaForCausalLM, PaliGemmaForConditionalGeneration +from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM +from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.normalize import Normalize, Unnormalize