From 6a75b4761a9102d94ab61073923a43eeb89bd764 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 7 Aug 2025 17:07:20 +0200 Subject: [PATCH] refactor(TokenizerProcessor): improve dependency handling and observation management - Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility. - Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed. - Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures. - Added error handling for missing transformers library, providing clear guidance for users on installation requirements. --- src/lerobot/processor/tokenizer_processor.py | 28 +++++++++++++++----- src/lerobot/utils/import_utils.py | 1 + 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 4ec9fb351..796acf924 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -5,14 +5,19 @@ Tokenizer processor for handling text tokenization in robot transitions. from __future__ import annotations from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from transformers import AutoTokenizer from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import OBS_LANGUAGE from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer +else: + AutoTokenizer = None @dataclass @@ -54,7 +59,7 @@ class TokenizerProcessor: """ tokenizer_name: str | None = None - tokenizer: AutoTokenizer | None = None + tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies max_length: int = 512 task_key: str = "task" padding_side: str = "right" @@ -66,10 +71,18 @@ class TokenizerProcessor: def __post_init__(self): """Initialize the tokenizer from the provided tokenizer or tokenizer name.""" + if not _transformers_available: + raise ImportError( + "The 'transformers' library is not installed. " + "Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor." + ) + if self.tokenizer is not None: # Use provided tokenizer object directly self._tokenizer = self.tokenizer elif self.tokenizer_name is not None: + if AutoTokenizer is None: + raise ImportError("AutoTokenizer is not available") self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) else: raise ValueError( @@ -125,9 +138,11 @@ class TokenizerProcessor: tokenized_prompt = self._tokenize_text(task) # Get or create observation dict - if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None: - transition[TransitionKey.OBSERVATION] = {} - observation = transition[TransitionKey.OBSERVATION] + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + observation = {} + else: + observation = dict(observation) # Make a copy # Add tokenized data to observation observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"] @@ -135,6 +150,7 @@ class TokenizerProcessor: dtype=torch.bool ) + transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc] return transition def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]: diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 5c29b5a84..09e649372 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _torch_available, _torch_version = is_package_available("torch", return_version=True) +_transformers_available = is_package_available("transformers") _gym_xarm_available = is_package_available("gym_xarm") _gym_aloha_available = is_package_available("gym_aloha") _gym_pusht_available = is_package_available("gym_pusht")