fix(policies): transformers import for ci in PI0 & PI05 (#2039)

* fix(policies): transformers import for ci in PI0

* fix(policies): transformers import for ci in PI05
This commit is contained in:
Steven Palma
2025-09-25 14:56:20 +02:00
committed by GitHub
parent 1c2826e869
commit 77fc8cc4d9
3 changed files with 37 additions and 13 deletions
+16 -6
View File
@@ -19,20 +19,30 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import Literal
from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
# Helper functions
+16 -6
View File
@@ -19,20 +19,30 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import Literal
from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
# Helper functions
+5 -1
View File
@@ -22,7 +22,6 @@ import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
@@ -39,6 +38,11 @@ from lerobot.processor import (
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")