mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
refactor(TokenizerProcessor): improve dependency handling and observation management
- Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility. - Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed. - Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures. - Added error handling for missing transformers library, providing clear guidance for users on installation requirements.
This commit is contained in:
@@ -5,14 +5,19 @@ Tokenizer processor for handling text tokenization in robot transitions.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -54,7 +59,7 @@ class TokenizerProcessor:
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: AutoTokenizer | None = None
|
||||
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
padding_side: str = "right"
|
||||
@@ -66,10 +71,18 @@ class TokenizerProcessor:
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self._tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("AutoTokenizer is not available")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -125,9 +138,11 @@ class TokenizerProcessor:
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Get or create observation dict
|
||||
if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None:
|
||||
transition[TransitionKey.OBSERVATION] = {}
|
||||
observation = transition[TransitionKey.OBSERVATION]
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
observation = {}
|
||||
else:
|
||||
observation = dict(observation) # Make a copy
|
||||
|
||||
# Add tokenized data to observation
|
||||
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
||||
@@ -135,6 +150,7 @@ class TokenizerProcessor:
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
|
||||
return transition
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
|
||||
|
||||
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
||||
_transformers_available = is_package_available("transformers")
|
||||
_gym_xarm_available = is_package_available("gym_xarm")
|
||||
_gym_aloha_available = is_package_available("gym_aloha")
|
||||
_gym_pusht_available = is_package_available("gym_pusht")
|
||||
|
||||
Reference in New Issue
Block a user