mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
refactor(TokenizerProcessor): improve dependency handling and observation management
- Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility. - Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed. - Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures. - Added error handling for missing transformers library, providing clear guidance for users on installation requirements.
This commit is contained in:
@@ -5,14 +5,19 @@ Tokenizer processor for handling text tokenization in robot transitions.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.constants import OBS_LANGUAGE
|
from lerobot.constants import OBS_LANGUAGE
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
else:
|
||||||
|
AutoTokenizer = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -54,7 +59,7 @@ class TokenizerProcessor:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer_name: str | None = None
|
tokenizer_name: str | None = None
|
||||||
tokenizer: AutoTokenizer | None = None
|
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
|
||||||
max_length: int = 512
|
max_length: int = 512
|
||||||
task_key: str = "task"
|
task_key: str = "task"
|
||||||
padding_side: str = "right"
|
padding_side: str = "right"
|
||||||
@@ -66,10 +71,18 @@ class TokenizerProcessor:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
||||||
|
if not _transformers_available:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'transformers' library is not installed. "
|
||||||
|
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
|
||||||
|
)
|
||||||
|
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
# Use provided tokenizer object directly
|
# Use provided tokenizer object directly
|
||||||
self._tokenizer = self.tokenizer
|
self._tokenizer = self.tokenizer
|
||||||
elif self.tokenizer_name is not None:
|
elif self.tokenizer_name is not None:
|
||||||
|
if AutoTokenizer is None:
|
||||||
|
raise ImportError("AutoTokenizer is not available")
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -125,9 +138,11 @@ class TokenizerProcessor:
|
|||||||
tokenized_prompt = self._tokenize_text(task)
|
tokenized_prompt = self._tokenize_text(task)
|
||||||
|
|
||||||
# Get or create observation dict
|
# Get or create observation dict
|
||||||
if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None:
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
transition[TransitionKey.OBSERVATION] = {}
|
if observation is None:
|
||||||
observation = transition[TransitionKey.OBSERVATION]
|
observation = {}
|
||||||
|
else:
|
||||||
|
observation = dict(observation) # Make a copy
|
||||||
|
|
||||||
# Add tokenized data to observation
|
# Add tokenized data to observation
|
||||||
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
||||||
@@ -135,6 +150,7 @@ class TokenizerProcessor:
|
|||||||
dtype=torch.bool
|
dtype=torch.bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
|||||||
|
|
||||||
|
|
||||||
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
||||||
|
_transformers_available = is_package_available("transformers")
|
||||||
_gym_xarm_available = is_package_available("gym_xarm")
|
_gym_xarm_available = is_package_available("gym_xarm")
|
||||||
_gym_aloha_available = is_package_available("gym_aloha")
|
_gym_aloha_available = is_package_available("gym_aloha")
|
||||||
_gym_pusht_available = is_package_available("gym_pusht")
|
_gym_pusht_available = is_package_available("gym_pusht")
|
||||||
|
|||||||
Reference in New Issue
Block a user