refactor(TokenizerProcessor): improve dependency handling and observation management

- Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, so the module can be imported even when transformers is not installed.
- Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed.
- Improved observation handling by reading the observation via transition.get and working on a defensive copy of the dict, avoiding in-place mutation of the caller's data while remaining compatible with existing data structures.
- Added error handling for missing transformers library, providing clear guidance for users on installation requirements.
This commit is contained in:
Adil Zouitine
2025-08-07 17:07:20 +02:00
parent e5ade5565d
commit 6a75b4761a
2 changed files with 23 additions and 6 deletions
+22 -6
View File
@@ -5,14 +5,19 @@ Tokenizer processor for handling text tokenization in robot transitions.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING, Any
import torch
from transformers import AutoTokenizer
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
else:
AutoTokenizer = None
@dataclass
@@ -54,7 +59,7 @@ class TokenizerProcessor:
"""
tokenizer_name: str | None = None
tokenizer: AutoTokenizer | None = None
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
@@ -66,10 +71,18 @@ class TokenizerProcessor:
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self._tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoTokenizer is None:
raise ImportError("AutoTokenizer is not available")
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
@@ -125,9 +138,11 @@ class TokenizerProcessor:
tokenized_prompt = self._tokenize_text(task)
# Get or create observation dict
if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None:
transition[TransitionKey.OBSERVATION] = {}
observation = transition[TransitionKey.OBSERVATION]
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
observation = {}
else:
observation = dict(observation) # Make a copy
# Add tokenized data to observation
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
@@ -135,6 +150,7 @@ class TokenizerProcessor:
dtype=torch.bool
)
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
return transition
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
+1
View File
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_transformers_available = is_package_available("transformers")
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")