diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a9306cc79..8ea741d08 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -19,20 +19,30 @@ import logging import math from collections import deque from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING, Literal import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers.models.auto import CONFIG_MAPPING -from transformers.models.gemma import modeling_gemma -from transformers.models.gemma.modeling_gemma import GemmaForCausalLM -from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None from lerobot.configs.policies import PreTrainedConfig -from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE # Helper functions diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 6a87caead..8db75913c 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -19,20 +19,30 @@ import logging import math from collections import deque from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING, Literal import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers.models.auto import CONFIG_MAPPING -from transformers.models.gemma import modeling_gemma -from transformers.models.gemma.modeling_gemma import GemmaForCausalLM -from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None from lerobot.configs.policies import PreTrainedConfig -from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS # Helper functions diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 9ec58d3e8..c7523d167 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -22,7 +22,6 @@ import numpy as np import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pi05.modeling_pi05 import pad_vector from lerobot.processor import ( @@ -39,6 +38,11 @@ from lerobot.processor import ( ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) @ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")