From 5326ffe77e6cc7a35a60da9294573b575328929c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Tue, 5 Aug 2025 10:53:08 +0200 Subject: [PATCH] feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. --- src/lerobot/configs/types.py | 1 + src/lerobot/constants.py | 1 + src/lerobot/policies/__init__.py | 11 + src/lerobot/policies/act/processor_act.py | 12 +- .../policies/diffusion/processor_diffusion.py | 12 +- src/lerobot/policies/factory.py | 6 +- src/lerobot/policies/pi0/modeling_pi0.py | 35 +- src/lerobot/policies/pi0/processor_pi0.py | 85 +- .../policies/pi0fast/processor_pi0fast.py | 12 +- src/lerobot/policies/sac/processor_sac.py | 12 +- .../sac/reward_model/processor_classifier.py | 4 +- .../policies/smolvla/modeling_smolvla.py | 150 +-- .../policies/smolvla/processor_smolvla.py | 71 +- src/lerobot/policies/tdmpc/processor_tdmpc.py | 12 +- src/lerobot/policies/vqbet/processor_vqbet.py | 12 +- src/lerobot/processor/__init__.py | 2 + src/lerobot/processor/batch_processor.py | 12 + src/lerobot/processor/device_processor.py | 83 +- src/lerobot/processor/normalize_processor.py | 76 +- src/lerobot/processor/pipeline.py | 18 +- src/lerobot/processor/tokenizer_processor.py | 210 +++++ tests/processor/test_batch_processor.py | 228 +++++ tests/processor/test_device_processor.py | 874 ++++++++++++++++++ tests/processor/test_normalize_processor.py | 267 ++++++ tests/processor/test_pipeline.py | 103 +++ tests/processor/test_tokenizer_processor.py | 699 ++++++++++++++ 26 files changed, 2776 insertions(+), 232 deletions(-) create mode 100644 src/lerobot/processor/tokenizer_processor.py create mode 100644 tests/processor/test_device_processor.py create mode 100644 tests/processor/test_tokenizer_processor.py diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index 6040ff70b..322a7ea9b 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -24,6 +24,7 @@ class FeatureType(str, Enum): ENV = "ENV" ACTION = "ACTION" REWARD = "REWARD" + LANGUAGE = "LANGUAGE" class NormalizationMode(str, Enum): diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index 30777239e..a502a9570 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -21,6 +21,7 @@ OBS_ENV_STATE = "observation.environment_state" OBS_STATE = "observation.state" OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" +OBS_LANGUAGE = "observation.language" ACTION = "action" REWARD = "next.reward" diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 9cb0f6234..9b9de9931 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,6 +15,17 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .pi0.processor_pi0 import Pi0NewLineProcessor from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig +from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig + +__all__ = [ + "ACTConfig", + "DiffusionConfig", + "PI0Config", + "SmolVLAConfig", + "TDMPCConfig", + "VQBeTConfig", +] diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index 64a1f6cc8..2ce01431c 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -17,7 +17,9 @@ import torch from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, UnnormalizerProcessor, @@ -28,15 +30,17 @@ def make_act_processor( config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index fccfe7064..f09f3c350 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -18,7 +18,9 @@ import torch from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, UnnormalizerProcessor, @@ -29,15 +31,17 @@ def make_diffusion_processor( config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 604de606e..9ea9fc267 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import TypedDict +from typing import Any, TypedDict from torch import nn from typing_extensions import Unpack @@ -111,6 +111,8 @@ class ProcessorConfigKwargs(TypedDict, total=False): preprocessor_config_filename: str | None postprocessor_config_filename: str | None + preprocessor_overrides: dict[str, Any] | None + postprocessor_overrides: dict[str, Any] | None def make_processor( @@ -142,10 +144,12 @@ def make_processor( RobotProcessor.from_pretrained( source=pretrained_path, config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"), + overrides=kwargs.get("preprocessor_overrides", {}), ), RobotProcessor.from_pretrained( source=pretrained_path, config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"), + overrides=kwargs.get("postprocessor_overrides", {}), ), ) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index ba8324305..6b3ba834d 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -56,9 +56,8 @@ from collections import deque import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoTokenizer -from lerobot.constants import ACTION, OBS_STATE +from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.paligemma_with_expert import ( PaliGemmaWithExpertConfig, @@ -226,16 +225,12 @@ class PI0Policy(PreTrainedPolicy): Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - # TODO(azouitine): Add tokenizer to pipeline - self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") self.model = PI0FlowMatching(config) self.reset() @@ -280,7 +275,8 @@ class PI0Policy(PreTrainedPolicy): if len(self._action_queue) == 0: images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"] + lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"] actions = self.model.sample_actions( images, img_masks, lang_tokens, lang_masks, state, noise=noise @@ -306,7 +302,8 @@ class PI0Policy(PreTrainedPolicy): images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"] + lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"] actions = self.prepare_action(batch) actions_is_pad = batch.get("action_is_pad") @@ -373,26 +370,6 @@ class PI0Policy(PreTrainedPolicy): return images, img_masks - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - - # PaliGemma prompt has to end with a new line - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding="max_length", - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - def _pi_aloha_decode_state(self, state): # Flip the joints. for motor_idx in [1, 2, 8, 9]: @@ -458,7 +435,7 @@ class PI0FlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config): + def __init__(self, config: PI0Config): super().__init__() self.config = config diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 50cc4d71f..4f67842c7 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -14,34 +14,107 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import torch +from lerobot.configs.types import PolicyFeature from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, RobotProcessor, ToBatchProcessor, + TokenizerProcessor, UnnormalizerProcessor, ) +from lerobot.processor.pipeline import ( + EnvTransition, + ProcessorStep, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.processor.rename_processor import RenameProcessor + + +@ProcessorStepRegistry.register(name="pi0_new_line_processor") +class Pi0NewLineProcessor(ProcessorStep): + """Add a new line to the end of the task if it doesn't have one. + This is required for the PaliGemma tokenizer. + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Check if complementary_data exists + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None or "task" not in complementary_data: + return transition + + task = complementary_data["task"] + if task is None: + return transition + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Add tokenized task features to the feature contract.""" + return features + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {} def make_pi0_processor( config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: - input_steps = [ + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma + TokenizerProcessor( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessor(device=config.device), ] - output_steps = [ + + output_steps: list[ProcessorStep] = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] + return RobotProcessor(steps=input_steps, name="pi0_preprocessor"), RobotProcessor( steps=output_steps, name="pi0_postprocessor" ) diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index 50cc4d71f..fd6ff3d92 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -18,7 +18,9 @@ import torch from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, UnnormalizerProcessor, @@ -29,15 +31,17 @@ def make_pi0_processor( config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 26ebaf18b..7e2573bcd 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -19,7 +19,9 @@ import torch from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, UnnormalizerProcessor, @@ -30,15 +32,17 @@ def make_sac_processor( config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py index 394e85a64..084634a73 100644 --- a/src/lerobot/policies/sac/reward_model/processor_classifier.py +++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py @@ -17,6 +17,7 @@ import torch from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.processor import ( + DeviceProcessor, IdentityProcessor, NormalizerProcessor, RobotProcessor, @@ -33,8 +34,9 @@ def make_classifier_processor( NormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), + DeviceProcessor(device=config.device), ] - output_steps = [IdentityProcessor()] + output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()] return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor( steps=output_steps, name="classifier_postprocessor" ) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index ff656febe..8df98c007 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -53,17 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") """ import math -import os -import re from collections import deque -import safetensors import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoProcessor -from lerobot.constants import ACTION, OBS_STATE +from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel @@ -72,102 +68,6 @@ from lerobot.policies.utils import ( ) from lerobot.utils.utils import get_safe_dtype -# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker -_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") - - -def canonicalise(k: str) -> str: - """ - Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a - normalisation-buffer key. - """ - return _VARIANT_RE.sub(".buffer_", k) - - -def standardise_state_dict( - checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True -) -> tuple[dict[str, torch.Tensor], list[str]]: - """ - • Re-keys `checkpoint ` so that every entry matches the *reference* key set. - • If several variant keys collapse to the same canonical name we keep the - first one and log the collision. - • Returns the new dict + a list of entries that could not be matched. - """ - out, collisions, unmatched = {}, {}, [] - - for k, v in checkpoint.items(): - canon = canonicalise(k) - if canon in ref_keys: - if canon in out: # duplicate after collapsing - collisions.setdefault(canon, []).append(k) - else: - out[canon] = v - else: - unmatched.append(k) - - if verbose: - for canon, variants in collisions.items(): - print(f"[standardise_state_dict] '{canon}' ← {variants}") - if unmatched: - print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") - - out.update({k: checkpoint[k] for k in unmatched}) - return out, unmatched - - -def rename_checkpoint_keys(checkpoint: dict, rename_str: str): - """ - Renames keys in a checkpoint dictionary based on the given rename string. - - Args: - checkpoint (dict): The checkpoint dictionary. - rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". - - Returns: - dict: The modified checkpoint with renamed keys. - """ - - rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) - - new_checkpoint = {} - for k, v in checkpoint.items(): - for old_key, new_key in rename_dict.items(): - if old_key in k: - k = k.replace(old_key, new_key) - new_checkpoint[k] = v - return new_checkpoint - - -def load_smolvla( - model: torch.nn.Module, - filename: str | os.PathLike, - *, - device: str = "cpu", - checkpoint_keys_mapping: str = "", -) -> torch.nn.Module: - state_dict = safetensors.torch.load_file(filename, device=device) - - # Optional user-supplied renames (e.g. "model._orig_mod.//model.") - if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: - state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) - - state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - - # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset - norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") - state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - - missing, unexpected = model.load_state_dict(state_dict, strict=False) - - if not all(key.startswith(norm_keys) for key in missing) or unexpected: - raise RuntimeError( - "SmolVLA %d missing / %d unexpected keys", - len(missing), - len(unexpected), - ) - - return model - def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" @@ -333,7 +233,6 @@ class SmolVLAPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer self.model = VLAFlowMatching(config) self.reset() @@ -343,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.n_action_steps), } - # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues - @classmethod - def _load_as_safetensor( - cls, - model: "SmolVLAPolicy", - model_file: str, - map_location: str, - strict: bool, - ): - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) - return load_smolvla( - model, - model_file, - device=map_location, - checkpoint_keys_mapping="model._orig_mod.//model.", - ) - def get_optim_params(self) -> dict: return self.parameters() @@ -375,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy): images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"] + lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"] actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) @@ -435,7 +318,8 @@ class SmolVLAPolicy(PreTrainedPolicy): images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"] + lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"] actions = self.prepare_action(batch) actions_is_pad = batch.get("actions_id_pad") loss_dict = {} @@ -499,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy): img_masks.append(mask) return images, img_masks - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - if isinstance(tasks, str): - tasks = [tasks] - - if len(tasks) == 1: - tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] - - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding=self.config.pad_language_to, - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - def _pi_aloha_decode_state(self, state): # Flip the joints. for motor_idx in [1, 2, 8, 9]: diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index a61bd144b..373583e0a 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -13,30 +13,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import torch +from lerobot.configs.types import PolicyFeature from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, + TokenizerProcessor, UnnormalizerProcessor, ) +from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey def make_smolvla_processor( config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + SmolVLANewLineProcessor(), + TokenizerProcessor( + tokenizer_name=config.vlm_model_name, + padding=config.pad_language_to, + padding_side="right", + max_length=config.tokenizer_max_length, + ), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), @@ -44,3 +60,50 @@ def make_smolvla_processor( return RobotProcessor(steps=input_steps, name="smolvla_preprocessor"), RobotProcessor( steps=output_steps, name="smolvla_postprocessor" ) + + +@ProcessorStepRegistry.register(name="smolvla_new_line_processor") +class SmolVLANewLineProcessor(ProcessorStep): + """Add a new line to the end of the task if it doesn't have one.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Check if complementary_data exists + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None or "task" not in complementary_data: + return transition + + task = complementary_data["task"] + if task is None: + return transition + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Add tokenized task features to the feature contract.""" + return features + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {} diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index b7c43780f..833fcb55b 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -18,7 +18,9 @@ import torch from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, UnnormalizerProcessor, @@ -29,15 +31,17 @@ def make_tdmpc_processor( config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index 7a0ae84da..c82632787 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -19,7 +19,9 @@ import torch from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.processor import ( + DeviceProcessor, NormalizerProcessor, + RenameProcessor, RobotProcessor, ToBatchProcessor, UnnormalizerProcessor, @@ -30,15 +32,17 @@ def make_vqbet_processor( config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[RobotProcessor, RobotProcessor]: input_steps = [ + RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys NormalizerProcessor( - features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats - ), - NormalizerProcessor( - features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, ), ToBatchProcessor(), + DeviceProcessor(device=config.device), ] output_steps = [ + DeviceProcessor(device="cpu"), UnnormalizerProcessor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index f2f117bf4..304229292 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -33,6 +33,7 @@ from .pipeline import ( TruncatedProcessor, ) from .rename_processor import RenameProcessor +from .tokenizer_processor import TokenizerProcessor __all__ = [ "ActionProcessor", @@ -51,6 +52,7 @@ __all__ = [ "RewardProcessor", "RobotProcessor", "ToBatchProcessor", + "TokenizerProcessor", "TransitionKey", "TruncatedProcessor", "VanillaObservationProcessor", diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 1f356e9bf..40017760b 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -106,6 +106,18 @@ class ToBatchProcessor: if isinstance(task_value, str): complementary_data["task"] = [task_value] + # Process index field - add batch dim if 0D + if "index" in complementary_data: + index_value = complementary_data["index"] + if isinstance(index_value, Tensor) and index_value.dim() == 0: + complementary_data["index"] = index_value.unsqueeze(0) + + # Process task_index field - add batch dim if 0D + if "task_index" in complementary_data: + task_index_value = complementary_data["task_index"] + if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: + complementary_data["task_index"] = task_index_value.unsqueeze(0) + def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {} diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 0f00bb470..c5c86a696 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -19,24 +19,61 @@ from typing import Any import torch from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.utils.utils import get_safe_torch_device +@ProcessorStepRegistry.register("device_processor") @dataclass class DeviceProcessor: - """Processes transitions by moving tensors to the specified device. + """Processes transitions by moving tensors to the specified device and optionally converting float dtypes. This processor ensures that all tensors in the transition are moved to the - specified device (CPU or GPU) before they are returned. + specified device (CPU or GPU) before they are returned. It can also convert + floating-point tensors to a specified dtype while preserving non-float types + (int, long, bool, etc.). """ device: torch.device = "cpu" + float_dtype: str | None = None def __post_init__(self): self.device = get_safe_torch_device(self.device) self.non_blocking = "cuda" in str(self.device) + # Validate and convert float_dtype string to torch dtype + if self.float_dtype is not None: + dtype_mapping = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float": torch.float32, + "double": torch.float64, + } + + if self.float_dtype not in dtype_mapping: + available_dtypes = list(dtype_mapping.keys()) + raise ValueError( + f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}" + ) + + self._target_float_dtype = dtype_mapping[self.float_dtype] + else: + self._target_float_dtype = None + + def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + """Process a tensor by moving to device and optionally converting float dtype.""" + # Move to device first + tensor = tensor.to(self.device, non_blocking=self.non_blocking) + + # Convert float dtype if specified and tensor is floating point + if self._target_float_dtype is not None and tensor.is_floating_point(): + tensor = tensor.to(dtype=self._target_float_dtype) + + return tensor + def __call__(self, transition: EnvTransition) -> EnvTransition: # Create a copy of the transition new_transition = transition.copy() @@ -45,7 +82,7 @@ class DeviceProcessor: observation = transition.get(TransitionKey.OBSERVATION) if observation is not None: new_observation = { - k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v + k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v for k, v in observation.items() } new_transition[TransitionKey.OBSERVATION] = new_observation @@ -53,30 +90,54 @@ class DeviceProcessor: # Process action tensor action = transition.get(TransitionKey.ACTION) if action is not None and isinstance(action, torch.Tensor): - new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking) + new_transition[TransitionKey.ACTION] = self._process_tensor(action) # Process reward tensor reward = transition.get(TransitionKey.REWARD) if reward is not None and isinstance(reward, torch.Tensor): - new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking) + new_transition[TransitionKey.REWARD] = self._process_tensor(reward) # Process done tensor done = transition.get(TransitionKey.DONE) if done is not None and isinstance(done, torch.Tensor): - new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking) + new_transition[TransitionKey.DONE] = self._process_tensor(done) # Process truncated tensor truncated = transition.get(TransitionKey.TRUNCATED) if truncated is not None and isinstance(truncated, torch.Tensor): - new_transition[TransitionKey.TRUNCATED] = truncated.to( - self.device, non_blocking=self.non_blocking - ) + new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated) + + # Process complementary data tensors + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is not None: + new_complementary_data = {} + + # Process all items in complementary_data + for key, value in complementary_data.items(): + if isinstance(value, torch.Tensor): + new_complementary_data[key] = self._process_tensor(value) + else: + new_complementary_data[key] = value + + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data return new_transition def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" - return {"device": self.device} + return {"device": self.device, "float_dtype": self.float_dtype} + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 881381e6d..94390b004 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -116,7 +116,7 @@ class NormalizerProcessor: if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): self.normalize_keys = set(self.normalize_keys) - def _normalize_obs(self, observation): + def _normalize_obs(self, observation, normalized_info): if observation is None: return None @@ -138,6 +138,7 @@ class NormalizerProcessor: # Skip normalization if mode is IDENTITY if norm_mode is NormalizationMode.IDENTITY: + normalized_info[key] = "IDENTITY" continue # Skip if no stats available for this key @@ -156,16 +157,18 @@ class NormalizerProcessor: if "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] processed[key] = (tensor - mean) / (std + self.eps) + normalized_info[key] = "MEAN_STD" elif norm_mode is NormalizationMode.MIN_MAX: if "min" in stats and "max" in stats: min_val, max_val = stats["min"], stats["max"] processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + normalized_info[key] = "MIN_MAX" else: raise ValueError(f"Unsupported normalization mode: {norm_mode}") return processed - def _normalize_action(self, action): + def _normalize_action(self, action, normalized_info): if action is None: return action @@ -174,6 +177,7 @@ class NormalizerProcessor: # Skip normalization if mode is IDENTITY if norm_mode is NormalizationMode.IDENTITY: + normalized_info["action"] = "IDENTITY" return action # Skip if no stats available for actions @@ -190,10 +194,12 @@ class NormalizerProcessor: if norm_mode is NormalizationMode.MEAN_STD: if "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] + normalized_info["action"] = "MEAN_STD" return (tensor - mean) / (std + self.eps) elif norm_mode is NormalizationMode.MIN_MAX: if "min" in stats and "max" in stats: min_val, max_val = stats["min"], stats["max"] + normalized_info["action"] = "MIN_MAX" return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 else: raise ValueError(f"Unsupported normalization mode: {norm_mode}") @@ -202,13 +208,24 @@ class NormalizerProcessor: raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) - action = self._normalize_action(transition.get(TransitionKey.ACTION)) + # Track what was normalized + normalized_info = {} + + observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info) + action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info) # Create a new transition with normalized values new_transition = transition.copy() new_transition[TransitionKey.OBSERVATION] = observation new_transition[TransitionKey.ACTION] = action + + # Add normalization info to complementary data + if normalized_info: + comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + comp_data = {} if comp_data is None else dict(comp_data) + comp_data["normalized_keys"] = normalized_info + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition def get_config(self) -> dict[str, Any]: @@ -289,7 +306,7 @@ class UnnormalizerProcessor: self.stats = self.stats or {} self._tensor_stats = _convert_stats_to_tensors(self.stats) - def _unnormalize_obs(self, observation): + def _unnormalize_obs(self, observation, unnormalized_info): if observation is None: return None keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] @@ -304,6 +321,7 @@ class UnnormalizerProcessor: # Skip unnormalization if mode is IDENTITY if norm_mode is NormalizationMode.IDENTITY: + unnormalized_info[key] = "IDENTITY" continue # Skip if no stats available for this key @@ -322,16 +340,18 @@ class UnnormalizerProcessor: if "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] processed[key] = tensor * std + mean + unnormalized_info[key] = "MEAN_STD" elif norm_mode is NormalizationMode.MIN_MAX: if "min" in stats and "max" in stats: min_val, max_val = stats["min"], stats["max"] processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val + unnormalized_info[key] = "MIN_MAX" else: raise ValueError(f"Unsupported normalization mode: {norm_mode}") return processed - def _unnormalize_action(self, action): + def _unnormalize_action(self, action, unnormalized_info): if action is None: return action @@ -340,6 +360,7 @@ class UnnormalizerProcessor: # Skip unnormalization if mode is IDENTITY if norm_mode is NormalizationMode.IDENTITY: + unnormalized_info["action"] = "IDENTITY" return action # Skip if no stats available for actions @@ -356,10 +377,12 @@ class UnnormalizerProcessor: if norm_mode is NormalizationMode.MEAN_STD: if "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] + unnormalized_info["action"] = "MEAN_STD" return tensor * std + mean elif norm_mode is NormalizationMode.MIN_MAX: if "min" in stats and "max" in stats: min_val, max_val = stats["min"], stats["max"] + unnormalized_info["action"] = "MIN_MAX" return (tensor + 1) / 2 * (max_val - min_val) + min_val else: raise ValueError(f"Unsupported normalization mode: {norm_mode}") @@ -368,13 +391,24 @@ class UnnormalizerProcessor: raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization") def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) - action = self._unnormalize_action(transition.get(TransitionKey.ACTION)) + # Track what was unnormalized + unnormalized_info = {} + + observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info) + action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info) # Create a new transition with unnormalized values new_transition = transition.copy() new_transition[TransitionKey.OBSERVATION] = observation new_transition[TransitionKey.ACTION] = action + + # Add unnormalization info to complementary data + if unnormalized_info: + comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + comp_data = {} if comp_data is None else dict(comp_data) + comp_data["unnormalized_keys"] = unnormalized_info + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition def get_config(self) -> dict[str, Any]: @@ -413,3 +447,29 @@ def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, An step.stats = stats step._tensor_stats = _convert_stats_to_tensors(stats) return robot_processor + + +def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: + """Rename keys in the stats dictionary according to the provided mapping. + + Args: + stats: The statistics dictionary with structure {feature_key: {stat_name: value}} + rename_map: Dictionary mapping old key names to new key names + + Returns: + A new stats dictionary with renamed keys + + Example: + >>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}} + >>> rename_map = {"observation.state": "observation.robot_state"} + >>> new_stats = rename_stats(stats, rename_map) + >>> # new_stats will have "observation.robot_state" instead of "observation.state" + """ + renamed_stats = {} + + for old_key, sub_stats in stats.items(): + # Use the new key if it exists in the rename map, otherwise keep the old key + new_key = rename_map.get(old_key, old_key) + renamed_stats[new_key] = deepcopy(sub_stats) + + return renamed_stats diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 6e1b2a2cb..7683cc25c 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -201,10 +201,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} observation = observation_keys if observation_keys else None - # Extract padding and task keys for complementary data + # Extract padding, task, index, and task_index keys for complementary data pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} - complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} + index_key = {"index": batch["index"]} if "index" in batch else {} + task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} + complementary_data = ( + {**pad_keys, **task_key, **index_key, **task_index_key} + if pad_keys or task_key or index_key or task_index_key + else {} + ) transition: EnvTransition = { TransitionKey.OBSERVATION: observation, @@ -231,7 +237,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: "info": transition.get(TransitionKey.INFO, {}), } - # Add padding and task data from complementary_data + # Add padding, task, index, and task_index data from complementary_data complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) if complementary_data: pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} @@ -240,6 +246,12 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: if "task" in complementary_data: batch["task"] = complementary_data["task"] + if "index" in complementary_data: + batch["index"] = complementary_data["index"] + + if "task_index" in complementary_data: + batch["task_index"] = complementary_data["task_index"] + # Handle observation - flatten dict to observation.* keys if it's a dict observation = transition.get(TransitionKey.OBSERVATION) if isinstance(observation, dict): diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py new file mode 100644 index 000000000..c7086d6ce --- /dev/null +++ b/src/lerobot/processor/tokenizer_processor.py @@ -0,0 +1,210 @@ +""" +Tokenizer processor for handling text tokenization in robot transitions. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import torch +from transformers import AutoTokenizer + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.constants import OBS_LANGUAGE +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + + +@dataclass +@ProcessorStepRegistry.register(name="tokenizer_processor") +class TokenizerProcessor: + """Tokenizes text tasks in complementary data using a huggingface tokenizer. + + This processor handles tokenization of task strings found in the complementary_data + using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions + to the observation data for model processing while preserving the original task string. + + The processor supports both single strings and lists of strings as task inputs. + + Args: + tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub + (e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used + with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored. + tokenizer: A tokenizer object (e.g., from transformers library) that implements + the __call__ method. If provided, tokenizer_name is ignored. This parameter + is not serialized and must be provided via overrides when loading. + max_length: Maximum sequence length for tokenization. Defaults to 512. + task_key: Key in complementary_data containing the task text. Defaults to "task". + padding: Padding strategy for tokenization. Defaults to "max_length". + truncation: Whether to truncate sequences longer than max_length. Defaults to True. + + Examples: + Using tokenizer name (auto-loaded): + ```python + processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128) + ``` + + Using custom tokenizer object: + ```python + from transformers import AutoTokenizer + + custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128) + ``` + """ + + tokenizer_name: str | None = None + tokenizer: AutoTokenizer | None = None + max_length: int = 512 + task_key: str = "task" + padding_side: str = "right" + padding: str = "max_length" + truncation: bool = True + + # Internal tokenizer instance (not serialized) + _tokenizer: Any = field(default=None, init=False, repr=False) + + def __post_init__(self): + """Initialize the tokenizer from the provided tokenizer or tokenizer name.""" + if self.tokenizer is not None: + # Use provided tokenizer object directly + self._tokenizer = self.tokenizer + elif self.tokenizer_name is not None: + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + else: + raise ValueError( + "Either 'tokenizer' or 'tokenizer_name' must be provided. " + "Pass a tokenizer object directly or a tokenizer name to auto-load." + ) + + def get_task(self, transition: EnvTransition) -> list[str] | None: + """Extract and normalize task from complementary data. + + Args: + transition: Input transition containing complementary_data. + + Returns: + List of task strings if task is present, None otherwise. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + if self.task_key not in complementary_data: + return None + + task = complementary_data[self.task_key] + if task is None: + return None + + # Convert to list of strings + if isinstance(task, str): + return [task] + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + return task + + return None + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Process the transition by tokenizing the task text. + + Args: + transition: Input transition containing complementary_data with task text. + + Returns: + Modified transition with tokenized task added to observation. + + Raises: + ValueError: If tokenizer initialization failed. + """ + task = self.get_task(transition) + if task is None: + return transition + + # Tokenize the task + tokenized_prompt = self._tokenize_text(task) + + # Get or create observation dict + if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None: + transition[TransitionKey.OBSERVATION] = {} + observation = transition[TransitionKey.OBSERVATION] + + # Add tokenized data to observation + observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"] + observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to( + dtype=torch.bool + ) + + return transition + + def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]: + """Tokenize text using the configured tokenizer. + + Args: + text: Text string or list of strings to tokenize. + + Returns: + Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'. + """ + return self._tokenizer( + text, + max_length=self.max_length, + truncation=self.truncation, + padding=self.padding, + padding_side=self.padding_side, + return_tensors="pt", + ) + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization. + + Note: Only tokenizer_name is saved, not the tokenizer object itself. + When loading, provide the tokenizer via overrides if needed. + """ + config = { + "max_length": self.max_length, + "task_key": self.task_key, + "padding_side": self.padding_side, + "padding": self.padding, + "truncation": self.truncation, + } + + # Only include tokenizer_name if it was used (not when tokenizer object was provided) + if self.tokenizer_name is not None: + config["tokenizer_name"] = self.tokenizer_name + + return config + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return state dictionary (empty for this processor).""" + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load state dictionary (no-op for this processor).""" + pass + + def reset(self) -> None: + """Reset processor state (no-op for this processor).""" + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Add tokenized task features to the feature contract. + + Args: + features: Input feature dictionary. + + Returns: + Updated feature dictionary with tokenized task features added. + """ + # Add features for tokenized output if they don't exist + # Standard tokenizer output includes tokens and attention_mask + tokens_key = f"{OBS_LANGUAGE}.tokens" + attention_mask_key = f"{OBS_LANGUAGE}.attention_mask" + + if tokens_key not in features: + features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) + + if attention_mask_key not in features: + features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) + + return features diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index 3d8cb8d49..c9c4cd1dd 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -899,3 +899,231 @@ def test_task_preserves_other_keys(): assert processed_comp_data["motor_id"] == "motor_456" assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"} assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0] + + +# Index and task_index specific tests +def test_index_scalar_to_1d(): + """Test that 0D index tensor gets unsqueezed to 1D.""" + processor = ToBatchProcessor() + + # Create 0D index tensor (scalar) + index_0d = torch.tensor(42, dtype=torch.int64) + complementary_data = {"index": index_0d} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["index"][0] == 42 + + +def test_task_index_scalar_to_1d(): + """Test that 0D task_index tensor gets unsqueezed to 1D.""" + processor = ToBatchProcessor() + + # Create 0D task_index tensor (scalar) + task_index_0d = torch.tensor(7, dtype=torch.int64) + complementary_data = {"task_index": task_index_0d} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"].dtype == torch.int64 + assert processed_comp_data["task_index"][0] == 7 + + +def test_index_and_task_index_together(): + """Test processing both index and task_index together.""" + processor = ToBatchProcessor() + + # Create 0D tensors for both + index_0d = torch.tensor(100, dtype=torch.int64) + task_index_0d = torch.tensor(3, dtype=torch.int64) + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + "task": "pick_object", + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check index + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"][0] == 100 + + # Check task_index + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"][0] == 3 + + # Check task is also processed + assert processed_comp_data["task"] == ["pick_object"] + + +def test_index_already_batched(): + """Test that already batched index tensors remain unchanged.""" + processor = ToBatchProcessor() + + # Create already batched tensors + index_1d = torch.tensor([42], dtype=torch.int64) + index_2d = torch.tensor([[42, 43]], dtype=torch.int64) + + # Test 1D (already batched) + complementary_data = {"index": index_1d} + transition = create_transition(complementary_data=complementary_data) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d) + + # Test 2D + complementary_data = {"index": index_2d} + transition = create_transition(complementary_data=complementary_data) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d) + + +def test_task_index_already_batched(): + """Test that already batched task_index tensors remain unchanged.""" + processor = ToBatchProcessor() + + # Create already batched tensors + task_index_1d = torch.tensor([7], dtype=torch.int64) + task_index_2d = torch.tensor([[7, 8]], dtype=torch.int64) + + # Test 1D (already batched) + complementary_data = {"task_index": task_index_1d} + transition = create_transition(complementary_data=complementary_data) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d) + + # Test 2D + complementary_data = {"task_index": task_index_2d} + transition = create_transition(complementary_data=complementary_data) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d) + + +def test_index_non_tensor_unchanged(): + """Test that non-tensor index values remain unchanged.""" + processor = ToBatchProcessor() + + complementary_data = { + "index": 42, # Plain int, not tensor + "task_index": [1, 2, 3], # List, not tensor + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"] == 42 + assert processed_comp_data["task_index"] == [1, 2, 3] + + +def test_index_dtype_preservation(): + """Test that index and task_index dtype is preserved during processing.""" + processor = ToBatchProcessor() + + # Test different dtypes + dtypes = [torch.int32, torch.int64, torch.long] + + for dtype in dtypes: + index_0d = torch.tensor(42, dtype=dtype) + task_index_0d = torch.tensor(7, dtype=dtype) + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].dtype == dtype + assert processed_comp_data["task_index"].dtype == dtype + + +def test_index_with_full_transition(): + """Test index/task_index processing with full transition data.""" + processor = ToBatchProcessor() + + # Create full transition with all components + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(64, 64, 3), + } + action = torch.randn(4) + complementary_data = { + "task": "navigate_to_goal", + "index": torch.tensor(1000, dtype=torch.int64), + "task_index": torch.tensor(5, dtype=torch.int64), + "episode_id": 123, + } + + transition = create_transition( + observation=observation, + action=action, + reward=0.5, + done=False, + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check all components are processed correctly + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["navigate_to_goal"] + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"][0] == 1000 + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"][0] == 5 + assert processed_comp_data["episode_id"] == 123 # Non-tensor field unchanged + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_index_device_compatibility(): + """Test processor works with index/task_index tensors on different devices.""" + processor = ToBatchProcessor() + + # Create tensors on GPU + index_0d = torch.tensor(42, dtype=torch.int64, device="cuda") + task_index_0d = torch.tensor(7, dtype=torch.int64, device="cuda") + + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check shapes and that tensors stayed on GPU + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.type == "cuda" + + +def test_empty_index_tensor(): + """Test handling of empty index tensors.""" + processor = ToBatchProcessor() + + # Empty 0D tensor doesn't make sense, but test empty 1D + index_empty = torch.tensor([], dtype=torch.int64) + complementary_data = {"index": index_empty} + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # Should remain unchanged (already 1D) + assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,) diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py new file mode 100644 index 000000000..5ca818c32 --- /dev/null +++ b/tests/processor/test_device_processor.py @@ -0,0 +1,874 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.processor import DeviceProcessor, RobotProcessor +from lerobot.processor.pipeline import TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper function to create a transition dictionary.""" + transition = {} + if observation is not None: + transition[TransitionKey.OBSERVATION] = observation + if action is not None: + transition[TransitionKey.ACTION] = action + if reward is not None: + transition[TransitionKey.REWARD] = reward + if done is not None: + transition[TransitionKey.DONE] = done + if truncated is not None: + transition[TransitionKey.TRUNCATED] = truncated + if info is not None: + transition[TransitionKey.INFO] = info + if complementary_data is not None: + transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return transition + + +def test_basic_functionality(): + """Test basic device processor functionality on CPU.""" + processor = DeviceProcessor(device="cpu") + + # Create a transition with CPU tensors + observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + action = torch.randn(5) + reward = torch.tensor(1.0) + done = torch.tensor(False) + truncated = torch.tensor(False) + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on CPU + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" + assert result[TransitionKey.ACTION].device.type == "cpu" + assert result[TransitionKey.REWARD].device.type == "cpu" + assert result[TransitionKey.DONE].device.type == "cpu" + assert result[TransitionKey.TRUNCATED].device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_functionality(): + """Test device processor functionality on CUDA.""" + processor = DeviceProcessor(device="cuda") + + # Create a transition with CPU tensors + observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + action = torch.randn(5) + reward = torch.tensor(1.0) + done = torch.tensor(False) + truncated = torch.tensor(False) + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on CUDA + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.REWARD].device.type == "cuda" + assert result[TransitionKey.DONE].device.type == "cuda" + assert result[TransitionKey.TRUNCATED].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_specific_cuda_device(): + """Test device processor with specific CUDA device.""" + processor = DeviceProcessor(device="cuda:0") + + observation = {"observation.state": torch.randn(10)} + action = torch.randn(5) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].device.index == 0 + + +def test_non_tensor_values(): + """Test that non-tensor values are preserved.""" + processor = DeviceProcessor(device="cpu") + + observation = { + "observation.state": torch.randn(10), + "observation.metadata": {"key": "value"}, # Non-tensor data + "observation.list": [1, 2, 3], # Non-tensor data + } + action = torch.randn(5) + info = {"episode": 1, "step": 42} + + transition = create_transition(observation=observation, action=action, info=info) + + result = processor(transition) + + # Check tensors are processed + assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) + assert isinstance(result[TransitionKey.ACTION], torch.Tensor) + + # Check non-tensor values are preserved + assert result[TransitionKey.OBSERVATION]["observation.metadata"] == {"key": "value"} + assert result[TransitionKey.OBSERVATION]["observation.list"] == [1, 2, 3] + assert result[TransitionKey.INFO] == {"episode": 1, "step": 42} + + +def test_none_values(): + """Test handling of None values.""" + processor = DeviceProcessor(device="cpu") + + # Test with None observation + transition = create_transition(observation=None, action=torch.randn(5)) + result = processor(transition) + assert TransitionKey.OBSERVATION not in result + assert result[TransitionKey.ACTION].device.type == "cpu" + + # Test with None action + transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) + result = processor(transition) + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert TransitionKey.ACTION not in result + + +def test_empty_observation(): + """Test handling of empty observation dictionary.""" + processor = DeviceProcessor(device="cpu") + + transition = create_transition(observation={}, action=torch.randn(5)) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_scalar_tensors(): + """Test handling of scalar tensors.""" + processor = DeviceProcessor(device="cpu") + + observation = {"observation.scalar": torch.tensor(1.5)} + action = torch.tensor(2.0) + reward = torch.tensor(0.5) + + transition = create_transition(observation=observation, action=action, reward=reward) + + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.scalar"].item() == 1.5 + assert result[TransitionKey.ACTION].item() == 2.0 + assert result[TransitionKey.REWARD].item() == 0.5 + + +def test_dtype_preservation(): + """Test that tensor dtypes are preserved.""" + processor = DeviceProcessor(device="cpu") + + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), + } + action = torch.randn(3, dtype=torch.float16) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + assert result[TransitionKey.ACTION].dtype == torch.float16 + + +def test_shape_preservation(): + """Test that tensor shapes are preserved.""" + processor = DeviceProcessor(device="cpu") + + observation = { + "observation.1d": torch.randn(10), + "observation.2d": torch.randn(5, 10), + "observation.3d": torch.randn(3, 224, 224), + "observation.4d": torch.randn(2, 3, 224, 224), + } + action = torch.randn(2, 5, 3) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.1d"].shape == (10,) + assert result[TransitionKey.OBSERVATION]["observation.2d"].shape == (5, 10) + assert result[TransitionKey.OBSERVATION]["observation.3d"].shape == (3, 224, 224) + assert result[TransitionKey.OBSERVATION]["observation.4d"].shape == (2, 3, 224, 224) + assert result[TransitionKey.ACTION].shape == (2, 5, 3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_mixed_devices(): + """Test handling of tensors already on different devices.""" + processor = DeviceProcessor(device="cuda") + + # Create tensors on different devices + observation = { + "observation.cpu": torch.randn(5), # CPU + "observation.cuda": torch.randn(5).cuda(), # Already on CUDA + } + action = torch.randn(3).cuda() # Already on CUDA + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # All should be on CUDA + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.cuda"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + +def test_non_blocking_flag(): + """Test that non_blocking flag is set correctly.""" + # CPU processor should have non_blocking=False + cpu_processor = DeviceProcessor(device="cpu") + assert cpu_processor.non_blocking is False + + # CUDA processor should have non_blocking=True + cuda_processor = DeviceProcessor(device="cuda") + assert cuda_processor.non_blocking is True + + cuda_0_processor = DeviceProcessor(device="cuda:0") + assert cuda_0_processor.non_blocking is True + + +def test_serialization_methods(): + """Test get_config, state_dict, and load_state_dict methods.""" + processor = DeviceProcessor(device="cuda") + + # Test get_config + config = processor.get_config() + assert config == {"device": "cuda", "float_dtype": None} + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should be no-op) + processor.load_state_dict({}) + assert processor.device == "cuda" + + # Test reset (should be no-op) + processor.reset() + assert processor.device == "cuda" + + +def test_feature_contract(): + """Test that feature_contract returns features unchanged.""" + processor = DeviceProcessor(device="cpu") + + features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + + result = processor.feature_contract(features) + assert result == features + assert result is features # Should return the same object + + +def test_integration_with_robot_processor(): + """Test integration with RobotProcessor.""" + from lerobot.processor import ToBatchProcessor + + # Create a pipeline with DeviceProcessor + device_processor = DeviceProcessor(device="cpu") + batch_processor = ToBatchProcessor() + + processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline") + + # Create test data + observation = {"observation.state": torch.randn(10)} + action = torch.randn(5) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors are batched and on correct device + assert result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1 # Batched + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.ACTION].shape[0] == 1 # Batched + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_save_and_load_pretrained(): + """Test saving and loading processor with DeviceProcessor.""" + processor = DeviceProcessor(device="cuda:0", float_dtype="float16") + robot_processor = RobotProcessor(steps=[processor], name="device_test_processor") + + with tempfile.TemporaryDirectory() as tmpdir: + # Save + robot_processor.save_pretrained(tmpdir) + + # Load + loaded_processor = RobotProcessor.from_pretrained(tmpdir) + + assert len(loaded_processor.steps) == 1 + loaded_device_processor = loaded_processor.steps[0] + assert isinstance(loaded_device_processor, DeviceProcessor) + assert loaded_device_processor.device == "cuda:0" + assert loaded_device_processor.float_dtype == "float16" + + +def test_registry_functionality(): + """Test that DeviceProcessor is properly registered.""" + from lerobot.processor.pipeline import ProcessorStepRegistry + + # Check that DeviceProcessor is registered + registered_class = ProcessorStepRegistry.get("device_processor") + assert registered_class is DeviceProcessor + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_performance_with_large_tensors(): + """Test performance with large tensors and non_blocking flag.""" + processor = DeviceProcessor(device="cuda") + + # Create large tensors + observation = { + "observation.large_image": torch.randn(10, 3, 512, 512), # Large image batch + "observation.features": torch.randn(10, 2048), # Large feature vector + } + action = torch.randn(10, 100) # Large action space + + transition = create_transition(observation=observation, action=action) + + # Process should not raise any errors + result = processor(transition) + + # Verify all tensors are on CUDA + assert result[TransitionKey.OBSERVATION]["observation.large_image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.features"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + +def test_reward_done_truncated_types(): + """Test handling of different types for reward, done, and truncated.""" + processor = DeviceProcessor(device="cpu") + + # Test with scalar values (not tensors) + transition = create_transition( + observation={"observation.state": torch.randn(5)}, + action=torch.randn(3), + reward=1.0, # float + done=False, # bool + truncated=True, # bool + ) + + result = processor(transition) + + # Non-tensor values should be preserved as-is + assert result[TransitionKey.REWARD] == 1.0 + assert result[TransitionKey.DONE] is False + assert result[TransitionKey.TRUNCATED] is True + + # Test with tensor values + transition = create_transition( + observation={"observation.state": torch.randn(5)}, + action=torch.randn(3), + reward=torch.tensor(1.0), + done=torch.tensor(False), + truncated=torch.tensor(True), + ) + + result = processor(transition) + + # Tensor values should be moved to device + assert isinstance(result[TransitionKey.REWARD], torch.Tensor) + assert isinstance(result[TransitionKey.DONE], torch.Tensor) + assert isinstance(result[TransitionKey.TRUNCATED], torch.Tensor) + assert result[TransitionKey.REWARD].device.type == "cpu" + assert result[TransitionKey.DONE].device.type == "cpu" + assert result[TransitionKey.TRUNCATED].device.type == "cpu" + + +def test_complementary_data_preserved(): + """Test that complementary_data is preserved unchanged.""" + processor = DeviceProcessor(device="cpu") + + complementary_data = { + "task": "pick_object", + "episode_id": 42, + "metadata": {"sensor": "camera_1"}, + "observation_is_pad": torch.tensor([False, False, True]), # This should be moved to device + } + + transition = create_transition( + observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data + ) + + result = processor(transition) + + # Check that complementary_data is preserved + assert TransitionKey.COMPLEMENTARY_DATA in result + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick_object" + assert result[TransitionKey.COMPLEMENTARY_DATA]["episode_id"] == 42 + assert result[TransitionKey.COMPLEMENTARY_DATA]["metadata"] == {"sensor": "camera_1"} + # Note: Currently DeviceProcessor doesn't process tensors in complementary_data + # This is intentional as complementary_data is typically metadata + + +def test_float_dtype_conversion(): + """Test float dtype conversion functionality.""" + processor = DeviceProcessor(device="cpu", float_dtype="float16") + + # Create tensors of different types + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + "observation.int64": torch.randint(0, 10, (5,), dtype=torch.int64), + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), + } + action = torch.randn(3, dtype=torch.float32) + reward = torch.tensor(1.0, dtype=torch.float32) + + transition = create_transition(observation=observation, action=action, reward=reward) + result = processor(transition) + + # Check that float tensors are converted to float16 + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + assert result[TransitionKey.REWARD].dtype == torch.float16 + + # Check that non-float tensors are preserved + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + + +def test_float_dtype_none(): + """Test that when float_dtype is None, no dtype conversion occurs.""" + processor = DeviceProcessor(device="cpu", float_dtype=None) + + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that dtypes are preserved when float_dtype is None + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.ACTION].dtype == torch.float64 + + +def test_float_dtype_bfloat16(): + """Test conversion to bfloat16.""" + processor = DeviceProcessor(device="cpu", float_dtype="bfloat16") + + observation = {"observation.state": torch.randn(5, dtype=torch.float32)} + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 + assert result[TransitionKey.ACTION].dtype == torch.bfloat16 + + +def test_float_dtype_float64(): + """Test conversion to float64.""" + processor = DeviceProcessor(device="cpu", float_dtype="float64") + + observation = {"observation.state": torch.randn(5, dtype=torch.float16)} + action = torch.randn(3, dtype=torch.float32) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 + assert result[TransitionKey.ACTION].dtype == torch.float64 + + +def test_float_dtype_invalid(): + """Test that invalid float_dtype raises ValueError.""" + with pytest.raises(ValueError, match="Invalid float_dtype 'invalid_dtype'"): + DeviceProcessor(device="cpu", float_dtype="invalid_dtype") + + +def test_float_dtype_aliases(): + """Test that dtype aliases work correctly.""" + # Test 'half' alias for float16 + processor_half = DeviceProcessor(device="cpu", float_dtype="half") + assert processor_half._target_float_dtype == torch.float16 + + # Test 'float' alias for float32 + processor_float = DeviceProcessor(device="cpu", float_dtype="float") + assert processor_float._target_float_dtype == torch.float32 + + # Test 'double' alias for float64 + processor_double = DeviceProcessor(device="cpu", float_dtype="double") + assert processor_double._target_float_dtype == torch.float64 + + +def test_float_dtype_with_mixed_tensors(): + """Test float dtype conversion with mixed tensor types.""" + processor = DeviceProcessor(device="cpu", float_dtype="float32") + + observation = { + "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert + "observation.state": torch.randn(10, dtype=torch.float64), # Should convert + "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert + "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert + } + action = torch.randn(5, dtype=torch.float16) # Should convert + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check conversions + assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged + assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged + assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted + + +def test_float_dtype_serialization(): + """Test that float_dtype is properly serialized in get_config.""" + processor = DeviceProcessor(device="cuda", float_dtype="float16") + config = processor.get_config() + + assert config == {"device": "cuda", "float_dtype": "float16"} + + # Test with None float_dtype + processor_none = DeviceProcessor(device="cpu", float_dtype=None) + config_none = processor_none.get_config() + + assert config_none == {"device": "cpu", "float_dtype": None} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_float_dtype_with_cuda(): + """Test float dtype conversion combined with CUDA device.""" + processor = DeviceProcessor(device="cuda", float_dtype="float16") + + # Create tensors on CPU with different dtypes + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.int64": torch.tensor([1, 2, 3], dtype=torch.int64), + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors are on CUDA and float types are converted + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + + assert result[TransitionKey.OBSERVATION]["observation.int64"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 # Unchanged + + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].dtype == torch.float16 + + +def test_complementary_data_index_fields(): + """Test processing of index and task_index fields in complementary_data.""" + processor = DeviceProcessor(device="cpu") + + # Create transition with index and task_index in complementary_data + complementary_data = { + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "episode_id": 123, # Non-tensor field + } + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check that tensors in complementary_data are processed + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check index tensor + assert isinstance(processed_comp_data["index"], torch.Tensor) + assert processed_comp_data["index"].device.type == "cpu" + assert torch.equal(processed_comp_data["index"], complementary_data["index"]) + + # Check task_index tensor + assert isinstance(processed_comp_data["task_index"], torch.Tensor) + assert processed_comp_data["task_index"].device.type == "cpu" + assert torch.equal(processed_comp_data["task_index"], complementary_data["task_index"]) + + # Check non-tensor fields remain unchanged + assert processed_comp_data["task"] == ["pick_cube"] + assert processed_comp_data["episode_id"] == 123 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_complementary_data_index_fields_cuda(): + """Test moving index and task_index fields to CUDA.""" + processor = DeviceProcessor(device="cuda:0") + + # Create CPU tensors + complementary_data = { + "index": torch.tensor([100, 101], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check tensors moved to CUDA + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["index"].device.index == 0 + assert processed_comp_data["task_index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.index == 0 + + +def test_complementary_data_without_index_fields(): + """Test that complementary_data without index/task_index fields works correctly.""" + processor = DeviceProcessor(device="cpu") + + complementary_data = { + "task": ["navigate"], + "episode_id": 456, + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # Should process without errors and preserve non-tensor fields + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["navigate"] + assert processed_comp_data["episode_id"] == 456 + + +def test_complementary_data_mixed_tensors(): + """Test complementary_data with mix of tensors and non-tensors.""" + processor = DeviceProcessor(device="cpu") + + complementary_data = { + "task": ["pick_and_place"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "metrics": [1.0, 2.0, 3.0], # List, not tensor + "config": {"speed": "fast"}, # Dict + "episode_id": 789, # Int + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check tensors are processed + assert isinstance(processed_comp_data["index"], torch.Tensor) + assert isinstance(processed_comp_data["task_index"], torch.Tensor) + + # Check non-tensors remain unchanged + assert processed_comp_data["task"] == ["pick_and_place"] + assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0] + assert processed_comp_data["config"] == {"speed": "fast"} + assert processed_comp_data["episode_id"] == 789 + + +def test_complementary_data_float_dtype_conversion(): + """Test that float dtype conversion doesn't affect int tensors in complementary_data.""" + processor = DeviceProcessor(device="cpu", float_dtype="float16") + + complementary_data = { + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "float_tensor": torch.tensor([1.5, 2.5], dtype=torch.float32), # Should be converted + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Int tensors should keep their dtype + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["task_index"].dtype == torch.int64 + + # Float tensor should be converted + assert processed_comp_data["float_tensor"].dtype == torch.float16 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_complementary_data_full_pipeline_cuda(): + """Test full transition with complementary_data on CUDA.""" + processor = DeviceProcessor(device="cuda:0", float_dtype="float16") + + # Create full transition with mixed CPU tensors + observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} + action = torch.randn(1, 4, dtype=torch.float32) + reward = torch.tensor(1.5, dtype=torch.float32) + done = torch.tensor(False) + complementary_data = { + "task": ["reach_target"], + "index": torch.tensor([1000], dtype=torch.int64), + "task_index": torch.tensor([10], dtype=torch.int64), + } + + transition = create_transition( + observation=observation, + action=action, + reward=reward, + done=done, + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check all components moved to CUDA + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.REWARD].device.type == "cuda" + assert result[TransitionKey.DONE].device.type == "cuda" + + # Check complementary_data tensors + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.type == "cuda" + + # Check float conversion happened for float tensors + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + assert result[TransitionKey.REWARD].dtype == torch.float16 + + # Check int tensors kept their dtype + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["task_index"].dtype == torch.int64 + + +def test_complementary_data_empty(): + """Test empty complementary_data handling.""" + processor = DeviceProcessor(device="cpu") + + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + complementary_data={}, + ) + + result = processor(transition) + + # Should have empty dict + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_none(): + """Test None complementary_data handling.""" + processor = DeviceProcessor(device="cpu") + + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + complementary_data=None, + ) + + result = processor(transition) + + # Complementary data should not be in the result (same as input) + assert TransitionKey.COMPLEMENTARY_DATA not in result + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_policy_processor_integration(): + """Test integration with policy processors - input on GPU, output on CPU.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.processor import NormalizerProcessor, ToBatchProcessor, UnnormalizerProcessor + + # Create features and stats + features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + + stats = { + "observation.state": {"mean": torch.zeros(10), "std": torch.ones(10)}, + "action": {"mean": torch.zeros(5), "std": torch.ones(5)}, + } + + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD} + + # Create input processor (preprocessor) that moves to GPU + input_processor = RobotProcessor( + steps=[ + NormalizerProcessor(features=features, norm_map=norm_map, stats=stats), + ToBatchProcessor(), + DeviceProcessor(device="cuda"), + ], + name="test_preprocessor", + ) + + # Create output processor (postprocessor) that moves to CPU + output_processor = RobotProcessor( + steps=[ + DeviceProcessor(device="cpu"), + UnnormalizerProcessor(features={"action": features["action"]}, norm_map=norm_map, stats=stats), + ], + name="test_postprocessor", + ) + + # Test data on CPU + observation = {"observation.state": torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation=observation, action=action) + + # Process through input processor + input_result = input_processor(transition) + + # Verify tensors are on GPU and batched + assert input_result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert input_result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1 + assert input_result[TransitionKey.ACTION].device.type == "cuda" + assert input_result[TransitionKey.ACTION].shape[0] == 1 + + # Simulate model output on GPU + model_output = create_transition(action=torch.randn(1, 5).cuda()) + + # Process through output processor + output_result = output_processor(model_output) + + # Verify action is back on CPU and unnormalized + assert output_result[TransitionKey.ACTION].device.type == "cpu" + assert output_result[TransitionKey.ACTION].shape == (1, 5) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5cc396f00..6fc60b49b 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -1260,6 +1260,273 @@ def test_hotswap_stats_with_different_data_types(): torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) +def test_normalization_info_tracking(): + """Test that normalization info is tracked in complementary_data.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + "action": { + "mean": np.array([0.0, 0.0]), + "std": np.array([1.0, 1.0]), + }, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Process the transition + normalized_transition = normalizer(transition) + + # Check that normalization info is added + comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) + assert comp_data is not None + assert "normalized_keys" in comp_data + + norm_info = comp_data["normalized_keys"] + assert norm_info["observation.image"] == "MEAN_STD" + assert norm_info["observation.state"] == "MIN_MAX" + assert norm_info["action"] == "IDENTITY" + + +def test_unnormalization_info_tracking(): + """Test that unnormalization info is tracked in complementary_data.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "action": { + "min": np.array([-1.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + } + + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + action = torch.tensor([0.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Process the transition + unnormalized_transition = unnormalizer(transition) + + # Check that unnormalization info is added + comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) + assert comp_data is not None + assert "unnormalized_keys" in comp_data + + unnorm_info = comp_data["unnormalized_keys"] + assert unnorm_info["observation.image"] == "MEAN_STD" + assert unnorm_info["action"] == "MIN_MAX" + + +def test_normalization_info_with_missing_stats(): + """Test normalization info when stats are missing for some keys.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + } + + # Only provide stats for image, not state + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + # Process the transition + normalized_transition = normalizer(transition) + + # Check that only keys with stats are in normalization info + comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) + assert comp_data is not None + assert "normalized_keys" in comp_data + + norm_info = comp_data["normalized_keys"] + assert norm_info["observation.image"] == "MEAN_STD" + # State should not be in the normalization info since it has no stats + assert "observation.state" not in norm_info + + +def test_normalization_info_with_selective_keys(): + """Test normalization info with selective normalization.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + } + + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + } + + # Only normalize image + normalizer = NormalizerProcessor( + features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"} + ) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + # Process the transition + normalized_transition = normalizer(transition) + + # Check that only selected keys are in normalization info + comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) + assert comp_data is not None + assert "normalized_keys" in comp_data + + norm_info = comp_data["normalized_keys"] + assert norm_info["observation.image"] == "MEAN_STD" + # State should not be in the normalization info since it wasn't in normalize_keys + assert "observation.state" not in norm_info + + +def test_normalization_info_preserved_in_pipeline(): + """Test that normalization info is preserved when using RobotProcessor pipeline.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "action": { + "min": np.array([-1.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + # Create pipeline + pipeline = RobotProcessor([normalizer, unnormalizer]) + + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + action = torch.tensor([0.5, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Process through pipeline + result = pipeline(transition) + + # Check that both normalization and unnormalization info are present + comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA) + assert comp_data is not None + assert "normalized_keys" in comp_data + assert "unnormalized_keys" in comp_data + + # Check normalization info + norm_info = comp_data["normalized_keys"] + assert norm_info["observation.image"] == "MEAN_STD" + assert norm_info["action"] == "MIN_MAX" + + # Check unnormalization info + unnorm_info = comp_data["unnormalized_keys"] + assert unnorm_info["observation.image"] == "MEAN_STD" + assert unnorm_info["action"] == "MIN_MAX" + + +def test_normalization_info_empty_transition(): + """Test that no normalization info is added for empty transitions.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"min": [-1.0], "max": [1.0]}, + } + + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + # Empty transition + transition = create_transition() + + # Process the transition + normalized_transition = normalizer(transition) + + # Check that no normalization info is added + comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA) + assert comp_data is None or "normalized_keys" not in comp_data + + def test_hotswap_stats_functional_test(): """Test that hotswapped processor actually works functionally.""" # Create test data diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 5665d5a7d..26e865fad 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -1639,6 +1639,109 @@ def test_state_file_naming_with_multiple_processors(): assert loaded_post.steps[0].window_size == 10 +def test_default_batch_to_transition_with_index_fields(): + """Test that _default_batch_to_transition handles index and task_index fields correctly.""" + from lerobot.processor.pipeline import _default_batch_to_transition + + # Create batch with index and task_index fields + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "next.reward": 1.5, + "next.done": False, + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + } + + transition = _default_batch_to_transition(batch) + + # Check basic transition structure + assert TransitionKey.OBSERVATION in transition + assert TransitionKey.ACTION in transition + assert TransitionKey.COMPLEMENTARY_DATA in transition + + # Check that index and task_index are in complementary_data + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + assert "index" in comp_data + assert "task_index" in comp_data + assert "task" in comp_data + + # Verify values + assert torch.equal(comp_data["index"], batch["index"]) + assert torch.equal(comp_data["task_index"], batch["task_index"]) + assert comp_data["task"] == batch["task"] + + +def test_default_transition_to_batch_with_index_fields(): + """Test that _default_transition_to_batch handles index and task_index fields correctly.""" + from lerobot.processor.pipeline import _default_transition_to_batch + + # Create transition with index and task_index in complementary_data + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + reward=1.5, + done=False, + complementary_data={ + "task": ["navigate"], + "index": torch.tensor([100], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + }, + ) + + batch = _default_transition_to_batch(transition) + + # Check that index and task_index are in the batch + assert "index" in batch + assert "task_index" in batch + assert "task" in batch + + # Verify values + assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"]) + assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"]) + assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"] + + +def test_batch_to_transition_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + from lerobot.processor.pipeline import _default_batch_to_transition + + # Batch without index/task_index + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "task": ["pick_cube"], + } + + transition = _default_batch_to_transition(batch) + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + + # Should have task but not index/task_index + assert "task" in comp_data + assert "index" not in comp_data + assert "task_index" not in comp_data + + +def test_transition_to_batch_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + from lerobot.processor.pipeline import _default_transition_to_batch + + # Transition without index/task_index + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + complementary_data={"task": ["navigate"]}, + ) + + batch = _default_transition_to_batch(transition) + + # Should have task but not index/task_index + assert "task" in batch + assert "index" not in batch + assert "task_index" not in batch + + def test_override_with_device_strings(): """Test overriding device parameters with string values.""" diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py new file mode 100644 index 000000000..b0c235b68 --- /dev/null +++ b/tests/processor/test_tokenizer_processor.py @@ -0,0 +1,699 @@ +""" +Tests for the TokenizerProcessor class. +""" + +import tempfile +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.constants import OBS_LANGUAGE +from lerobot.processor.pipeline import RobotProcessor, TransitionKey +from lerobot.processor.tokenizer_processor import TokenizerProcessor + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper function to create test transitions.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +class MockTokenizer: + """Mock tokenizer for testing that mimics transformers tokenizer interface.""" + + def __init__(self, vocab_size: int = 1000): + self.vocab_size = vocab_size + + def __call__( + self, + text: str | list[str], + max_length: int = 512, + truncation: bool = True, + padding: str = "max_length", + padding_side: str = "right", + return_tensors: str = "pt", + **kwargs, + ) -> dict[str, torch.Tensor]: + """Mock tokenization that returns deterministic tokens based on text.""" + if isinstance(text, str): + texts = [text] + else: + texts = text + + batch_size = len(texts) + + # Create mock input_ids and attention_mask + input_ids = torch.zeros(batch_size, max_length, dtype=torch.long) + attention_mask = torch.zeros(batch_size, max_length, dtype=torch.long) + + for i, txt in enumerate(texts): + # Simple mock: use hash of text to generate deterministic tokens + text_hash = hash(txt) % self.vocab_size + seq_len = min(len(txt.split()), max_length) + + # Fill input_ids with simple pattern based on text + for j in range(seq_len): + input_ids[i, j] = (text_hash + j) % self.vocab_size + + # Set attention mask for non-padded positions + attention_mask[i, :seq_len] = 1 + + result = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Return single sequence for single input to match transformers behavior + if len(texts) == 1: + result = {k: v.squeeze(0) for k, v in result.items()} + + return result + + +@pytest.fixture +def mock_tokenizer(): + """Provide a mock tokenizer for testing.""" + return MockTokenizer(vocab_size=100) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_basic_tokenization(mock_auto_tokenizer): + """Test basic string tokenization functionality.""" + # Mock AutoTokenizer.from_pretrained to return our mock tokenizer + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10) + + transition = create_transition(complementary_data={"task": "pick up the red cube"}) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube" + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check token structure + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert isinstance(tokens, torch.Tensor) + assert isinstance(attention_mask, torch.Tensor) + assert tokens.shape == (10,) + assert attention_mask.shape == (10,) + + +def test_basic_tokenization_with_tokenizer_object(): + """Test basic string tokenization functionality using tokenizer object directly.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition(complementary_data={"task": "pick up the red cube"}) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube" + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check token structure + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert isinstance(tokens, torch.Tensor) + assert isinstance(attention_mask, torch.Tensor) + assert tokens.shape == (10,) + assert attention_mask.shape == (10,) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_list_of_strings_tokenization(mock_auto_tokenizer): + """Test tokenization of a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8) + + transition = create_transition(complementary_data={"task": ["pick up cube", "place on table"]}) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["pick up cube", "place on table"] + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (2, 8) # batch_size=2, seq_len=8 + assert attention_mask.shape == (2, 8) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_custom_keys(mock_auto_tokenizer): + """Test using custom task_key.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5) + + transition = create_transition(complementary_data={"instruction": "move forward"}) + + result = processor(transition) + + # Check that tokens are stored in observation regardless of task_key + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + assert tokens.shape == (5,) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_none_complementary_data(mock_auto_tokenizer): + """Test handling of None complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + transition = create_transition(complementary_data=None) + + result = processor(transition) + assert result == transition # Should return unchanged + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_missing_task_key(mock_auto_tokenizer): + """Test handling when task key is missing.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + transition = create_transition(complementary_data={"other_field": "some value"}) + + result = processor(transition) + assert result == transition # Should return unchanged + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_none_task_value(mock_auto_tokenizer): + """Test handling when task value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + transition = create_transition(complementary_data={"task": None}) + + result = processor(transition) + assert result == transition # Should return unchanged + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_unsupported_task_type(mock_auto_tokenizer): + """Test handling of unsupported task types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + # Test with integer task + transition = create_transition(complementary_data={"task": 123}) + + result = processor(transition) + assert result == transition # Should return unchanged + + # Test with mixed list + transition = create_transition(complementary_data={"task": ["text", 123, "more text"]}) + + result = processor(transition) + assert result == transition # Should return unchanged + + +def test_no_tokenizer_error(): + """Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided.""" + with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"): + TokenizerProcessor() + + +def test_invalid_tokenizer_name_error(): + """Test that error is raised when invalid tokenizer_name is provided.""" + with patch("lerobot.processor.tokenizer_processor.AutoTokenizer") as mock_auto_tokenizer: + # Mock import error + mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found") + + with pytest.raises(Exception, match="Model not found"): + TokenizerProcessor(tokenizer_name="invalid-tokenizer") + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_get_config_with_tokenizer_name(mock_auto_tokenizer): + """Test configuration serialization when using tokenizer_name.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor( + tokenizer_name="test-tokenizer", + max_length=256, + task_key="instruction", + padding="longest", + truncation=False, + ) + + config = processor.get_config() + + expected = { + "tokenizer_name": "test-tokenizer", + "max_length": 256, + "task_key": "instruction", + "padding_side": "right", + "padding": "longest", + "truncation": False, + } + + assert config == expected + + +def test_get_config_with_tokenizer_object(): + """Test configuration serialization when using tokenizer object.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + processor = TokenizerProcessor( + tokenizer=mock_tokenizer, + max_length=256, + task_key="instruction", + padding="longest", + truncation=False, + ) + + config = processor.get_config() + + # tokenizer_name should not be in config when tokenizer object is used + expected = { + "max_length": 256, + "task_key": "instruction", + "padding_side": "right", + "padding": "longest", + "truncation": False, + } + + assert config == expected + assert "tokenizer_name" not in config + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_state_dict_methods(mock_auto_tokenizer): + """Test state_dict and load_state_dict methods.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + # Should return empty dict + state = processor.state_dict() + assert state == {} + + # load_state_dict should not raise error + processor.load_state_dict({}) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_reset_method(mock_auto_tokenizer): + """Test reset method.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + # Should not raise error + processor.reset() + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_integration_with_robot_processor(mock_auto_tokenizer): + """Test integration with RobotProcessor.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6) + robot_processor = RobotProcessor([tokenizer_processor]) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + + result = robot_processor(transition) + + # Check that observation exists and tokenization was applied + assert TransitionKey.OBSERVATION in result + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (6,) + assert attention_mask.shape == (6,) + + # Check that other data is preserved + assert torch.equal( + result[TransitionKey.OBSERVATION]["state"], transition[TransitionKey.OBSERVATION]["state"] + ) + assert torch.equal(result[TransitionKey.ACTION], transition[TransitionKey.ACTION]) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer): + """Test saving and loading processor with tokenizer_name.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + original_processor = TokenizerProcessor( + tokenizer_name="test-tokenizer", max_length=32, task_key="instruction" + ) + + robot_processor = RobotProcessor([original_processor]) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save processor + robot_processor.save_pretrained(temp_dir) + + # Load processor - tokenizer will be recreated from saved config + loaded_processor = RobotProcessor.from_pretrained(temp_dir) + + # Test that loaded processor works + transition = create_transition(complementary_data={"instruction": "test instruction"}) + + result = loaded_processor(transition) + assert TransitionKey.OBSERVATION in result + assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] + + +def test_save_and_load_pretrained_with_tokenizer_object(): + """Test saving and loading processor with tokenizer object using overrides.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction") + + robot_processor = RobotProcessor([original_processor]) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save processor + robot_processor.save_pretrained(temp_dir) + + # Load processor with tokenizer override (since tokenizer object wasn't saved) + loaded_processor = RobotProcessor.from_pretrained( + temp_dir, overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}} + ) + + # Test that loaded processor works + transition = create_transition(complementary_data={"instruction": "test instruction"}) + + result = loaded_processor(transition) + assert TransitionKey.OBSERVATION in result + assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] + + +def test_registry_functionality(): + """Test that the processor is properly registered.""" + from lerobot.processor.pipeline import ProcessorStepRegistry + + # Check that the processor is registered + assert "tokenizer_processor" in ProcessorStepRegistry.list() + + # Check that we can retrieve it + retrieved_class = ProcessorStepRegistry.get("tokenizer_processor") + assert retrieved_class is TokenizerProcessor + + +def test_feature_contract_basic(): + """Test basic feature contract functionality.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128) + + input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + + output_features = processor.feature_contract(input_features) + + # Check that original features are preserved + assert "observation.state" in output_features + assert "action" in output_features + + # Check that tokenized features are added + assert f"{OBS_LANGUAGE}.tokens" in output_features + assert f"{OBS_LANGUAGE}.attention_mask" in output_features + + # Check feature properties + tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"] + attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens_feature.type == FeatureType.LANGUAGE + assert tokens_feature.shape == (128,) + assert attention_mask_feature.type == FeatureType.LANGUAGE + assert attention_mask_feature.shape == (128,) + + +def test_feature_contract_with_custom_max_length(): + """Test feature contract with custom max_length.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64) + + input_features = {} + output_features = processor.feature_contract(input_features) + + # Check that features use correct max_length + assert f"{OBS_LANGUAGE}.tokens" in output_features + assert f"{OBS_LANGUAGE}.attention_mask" in output_features + + tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"] + attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens_feature.shape == (64,) + assert attention_mask_feature.shape == (64,) + + +def test_feature_contract_existing_features(): + """Test feature contract when tokenized features already exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256) + + input_features = { + f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), + f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), + } + + output_features = processor.feature_contract(input_features) + + # Should not overwrite existing features + assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,) # Original shape preserved + assert output_features[f"{OBS_LANGUAGE}.attention_mask"].shape == (100,) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_tokenization_parameters(mock_auto_tokenizer): + """Test that tokenization parameters are correctly passed to tokenizer.""" + + # Create a custom mock that tracks calls + class TrackingMockTokenizer: + def __init__(self): + self.last_call_args = None + self.last_call_kwargs = None + + def __call__(self, *args, **kwargs): + self.last_call_args = args + self.last_call_kwargs = kwargs + # Return minimal valid output + return { + "input_ids": torch.zeros(16, dtype=torch.long), + "attention_mask": torch.ones(16, dtype=torch.long), + } + + tracking_tokenizer = TrackingMockTokenizer() + mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer + + processor = TokenizerProcessor( + tokenizer_name="test-tokenizer", + max_length=16, + padding="longest", + truncation=False, + padding_side="left", + ) + + transition = create_transition(complementary_data={"task": "test task"}) + + processor(transition) + + # Check that parameters were passed correctly (task is converted to list) + assert tracking_tokenizer.last_call_args == (["test task"],) + assert tracking_tokenizer.last_call_kwargs["max_length"] == 16 + assert tracking_tokenizer.last_call_kwargs["padding"] == "longest" + assert tracking_tokenizer.last_call_kwargs["padding_side"] == "left" + assert tracking_tokenizer.last_call_kwargs["truncation"] is False + assert tracking_tokenizer.last_call_kwargs["return_tensors"] == "pt" + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_preserves_other_complementary_data(mock_auto_tokenizer): + """Test that other complementary data fields are preserved.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer") + + transition = create_transition( + complementary_data={ + "task": "test task", + "episode_id": 123, + "timestamp": 456.789, + "other_field": {"nested": "data"}, + } + ) + + result = processor(transition) + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check that all original fields are preserved + assert comp_data["task"] == "test task" + assert comp_data["episode_id"] == 123 + assert comp_data["timestamp"] == 456.789 + assert comp_data["other_field"] == {"nested": "data"} + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_deterministic_tokenization(mock_auto_tokenizer): + """Test that tokenization is deterministic for the same input.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10) + + transition = create_transition(complementary_data={"task": "consistent test"}) + + result1 = processor(transition) + result2 = processor(transition) + + tokens1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + tokens2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + # Results should be identical + assert torch.equal(tokens1, tokens2) + assert torch.equal(attention_mask1, attention_mask2) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_empty_string_task(mock_auto_tokenizer): + """Test handling of empty string task.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8) + + transition = create_transition(complementary_data={"task": ""}) + + result = processor(transition) + + # Should still tokenize (mock tokenizer handles empty strings) + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + assert tokens.shape == (8,) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_very_long_task(mock_auto_tokenizer): + """Test handling of very long task strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True) + + long_task = " ".join(["word"] * 100) # Very long task + transition = create_transition(complementary_data={"task": long_task}) + + result = processor(transition) + + # Should be truncated to max_length + observation = result[TransitionKey.OBSERVATION] + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (5,) + assert attention_mask.shape == (5,) + + +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_custom_padding_side(mock_auto_tokenizer): + """Test using custom padding_side parameter.""" + + # Create a mock tokenizer that tracks padding_side calls + class PaddingSideTrackingTokenizer: + def __init__(self): + self.padding_side_calls = [] + + def __call__( + self, + text, + max_length=512, + truncation=True, + padding="max_length", + padding_side="right", + return_tensors="pt", + **kwargs, + ): + self.padding_side_calls.append(padding_side) + # Return minimal valid output + return { + "input_ids": torch.zeros(max_length, dtype=torch.long), + "attention_mask": torch.ones(max_length, dtype=torch.long), + } + + tracking_tokenizer = PaddingSideTrackingTokenizer() + mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer + + # Test left padding + processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left") + + transition = create_transition(complementary_data={"task": "test task"}) + processor_left(transition) + + assert tracking_tokenizer.padding_side_calls[-1] == "left" + + # Test right padding (default) + processor_right = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="right") + + processor_right(transition) + + assert tracking_tokenizer.padding_side_calls[-1] == "right"