refactor(processor): improve processor pipeline typing with generic type (#1810)

* refactor(processor): introduce generic type for to_output

- Always return `TOutput`
- Remove `_prepare_transition`, so `__call__` now always returns `TOutput` as well
- Update tests accordingly
- This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor

* refactor(processor): consolidate ProcessorKwargs usage across policies

- Remove the ProcessorTypes module and integrate ProcessorKwargs directly into the processor pipeline.
- Update multiple policy files to use the new ProcessorKwargs structure for preprocessor and postprocessor arguments.
- Simplify the handling of processor kwargs by initializing them to empty dictionaries when not provided.
This commit is contained in:
Adil Zouitine
2025-09-02 12:57:14 +02:00
committed by GitHub
parent 08fb310eaa
commit d32b76cc66
26 changed files with 847 additions and 220 deletions
+22 -3
View File
@@ -20,6 +20,7 @@ from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -28,8 +29,16 @@ from lerobot.processor import (
def make_act_pre_post_processors(
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: ACTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -46,6 +55,16 @@ def make_act_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -21,6 +21,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -29,8 +30,16 @@ from lerobot.processor import (
def make_diffusion_pre_post_processors(
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -47,6 +56,15 @@ def make_diffusion_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+50 -11
View File
@@ -38,7 +38,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor.pipeline import ProcessorKwargs, RobotProcessor
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
@@ -114,6 +114,8 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
preprocessor_kwargs: ProcessorKwargs | None
postprocessor_kwargs: ProcessorKwargs | None
def make_pre_post_processors(
@@ -139,16 +141,24 @@ def make_pre_post_processors(
NotImplementedError: If the policy type doesn't have a processor implemented.
"""
if pretrained_path:
# Extract preprocessor and postprocessor kwargs
preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
return (
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=preprocessor_kwargs.get("to_transition"),
to_output=preprocessor_kwargs.get("to_output"),
),
RobotProcessor.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=postprocessor_kwargs.get("to_transition"),
to_output=postprocessor_kwargs.get("to_output"),
),
)
@@ -157,61 +167,90 @@ def make_pre_post_processors(
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
processors = make_tdmpc_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, DiffusionConfig):
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
processors = make_diffusion_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, ACTConfig):
from lerobot.policies.act.processor_act import make_act_pre_post_processors
processors = make_act_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
processors = make_vqbet_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
processors = make_pi0_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, PI0Config):
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
processors = make_pi0fast_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
processors = make_sac_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, RewardClassifierConfig):
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"))
processors = make_classifier_processor(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
elif isinstance(policy_cfg, SmolVLAConfig):
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
processors = make_smolvla_pre_post_processors(
config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
)
else:
+21 -3
View File
@@ -22,6 +22,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
@@ -65,8 +66,16 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor):
def make_pi0_pre_post_processors(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
@@ -93,6 +102,15 @@ def make_pi0_pre_post_processors(
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -21,6 +21,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -29,8 +30,16 @@ from lerobot.processor import (
def make_pi0fast_pre_post_processors(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
@@ -47,6 +56,15 @@ def make_pi0fast_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+21 -3
View File
@@ -22,6 +22,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -30,8 +31,16 @@ from lerobot.processor import (
def make_sac_pre_post_processors(
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: SACConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -48,6 +57,15 @@ def make_sac_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -20,13 +20,22 @@ from lerobot.processor import (
DeviceProcessor,
IdentityProcessor,
NormalizerProcessor,
ProcessorKwargs,
RobotProcessor,
)
def make_classifier_processor(
config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: RewardClassifierConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
@@ -37,6 +46,16 @@ def make_classifier_processor(
DeviceProcessor(device=config.device),
]
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
steps=output_steps, name="classifier_postprocessor"
return (
RobotProcessor(
steps=input_steps,
name="classifier_preprocessor",
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name="classifier_postprocessor",
**postprocessor_kwargs,
),
)
@@ -21,6 +21,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -34,8 +35,16 @@ from lerobot.processor.pipeline import (
def make_smolvla_pre_post_processors(
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: SmolVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessor(
@@ -59,8 +68,17 @@ def make_smolvla_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+21 -3
View File
@@ -21,6 +21,7 @@ from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -29,8 +30,16 @@ from lerobot.processor import (
def make_tdmpc_pre_post_processors(
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
@@ -47,6 +56,15 @@ def make_tdmpc_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+21 -3
View File
@@ -22,6 +22,7 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
ProcessorKwargs,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
@@ -30,8 +31,16 @@ from lerobot.processor import (
def make_vqbet_pre_post_processors(
config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: VQBeTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[RobotProcessor, RobotProcessor]:
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
postprocessor_kwargs = {}
input_steps = [
RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys
NormalizerProcessor(
@@ -48,6 +57,15 @@ def make_vqbet_pre_post_processors(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
return (
RobotProcessor(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
RobotProcessor(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+2
View File
@@ -37,6 +37,7 @@ from .pipeline import (
IdentityProcessor,
InfoProcessor,
ObservationProcessor,
ProcessorKwargs,
ProcessorStep,
ProcessorStepRegistry,
RewardProcessor,
@@ -68,6 +69,7 @@ __all__ = [
"UnnormalizerProcessor",
"hotswap_stats",
"ObservationProcessor",
"ProcessorKwargs",
"ProcessorStep",
"ProcessorStepRegistry",
"RenameProcessor",
+100 -77
View File
@@ -24,7 +24,7 @@ from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, TypedDict
from typing import Any, Generic, TypedDict, TypeVar, cast
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -33,6 +33,9 @@ from safetensors.torch import load_file, save_file
from lerobot.configs.types import PolicyFeature
# Type variable for generic processor output type
TOutput = TypeVar("TOutput")
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
@@ -216,6 +219,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
metadata without breaking the processor.
"""
# Validate input type
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
# Extract observation keys
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
@@ -279,8 +286,15 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
return batch
class ProcessorKwargs(TypedDict, total=False):
"""Keyword arguments for RobotProcessor constructor."""
to_transition: Callable[[dict[str, Any]], EnvTransition] | None
to_output: Callable[[EnvTransition], Any] | None
@dataclass
class RobotProcessor(ModelHubMixin):
class RobotProcessor(ModelHubMixin, Generic[TOutput]):
"""
Composable, debuggable post-processing processor for robot transitions.
@@ -288,20 +302,43 @@ class RobotProcessor(ModelHubMixin):
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
and batch dictionaries, automatically converting between formats as needed.
The processor is generic over its output type TOutput, which provides better type safety
and clarity about what the processor returns.
Args:
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
name: Human-readable identifier that is persisted inside the JSON config.
Defaults to "RobotProcessor".
to_transition: Function to convert batch dict to EnvTransition dict.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition dict to the desired output format.
Usually it is a batch dict or EnvTransition dict.
Defaults to _default_transition_to_batch.
to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
Defaults to _default_transition_to_batch (returns batch dict).
Use identity function (lambda x: x) for EnvTransition output.
before_step_hooks: List of hooks called before each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
after_step_hooks: List of hooks called after each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
Type Safety Examples:
```python
# Default behavior - returns batch dict
processor: RobotProcessor[dict[str, Any]] = RobotProcessor(steps=[some_step1, some_step2])
result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict
# For EnvTransition output, explicitly specify identity function
transition_processor: RobotProcessor[EnvTransition] = RobotProcessor(
steps=[some_step1, some_step2],
to_output=lambda x: x, # Identity function
)
result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition
# For custom output types
processor: RobotProcessor[str] = RobotProcessor(
steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys"
)
result: str = processor(batch_data) # Type checker knows this is str
```
Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to
reorder hooks after registration without creating a new pipeline.
@@ -323,8 +360,13 @@ class RobotProcessor(ModelHubMixin):
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False
)
to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
default_factory=lambda: _default_transition_to_batch, repr=False
to_output: Callable[[EnvTransition], TOutput] = field(
# Cast is necessary here: Working around Python type-checker limitation.
# _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput
# for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any],
# making this cast safe.
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
repr=False,
)
# Processor-level hooks for observation/monitoring
@@ -332,98 +374,57 @@ class RobotProcessor(ModelHubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
def __call__(self, data: EnvTransition | dict[str, Any]):
def __call__(self, data: dict[str, Any]) -> TOutput:
"""Process data through all steps.
The method accepts either the classic EnvTransition dict or a batch dictionary
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
it is first converted to the internal dict format using to_transition; after all
steps are executed the dict is transformed back into a batch dict with to_batch and the
result is returned thereby preserving the caller's original data type.
The method accepts a batch dictionary (like the ones returned by ReplayBuffer or
LeRobotDataset). It is first converted to EnvTransition format using to_transition,
then processed through all steps, and finally converted to the output format using to_output.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
data: A batch dictionary to process.
Returns:
The processed data in the same format as the input (EnvTransition or batch dict).
Raises:
ValueError: If the transition is not a valid EnvTransition format.
The processed data in the format specified by to_output.
"""
# Check if we need to convert back to batch format at the end
_, called_with_batch = self._prepare_transition(data)
# Always convert input through to_transition
transition = self.to_transition(data)
# Use step_through to get the iterator
step_iterator = self.step_through(data)
# Get initial state (before any steps)
current_transition = next(step_iterator)
# Process each step with hooks
for idx, next_transition in enumerate(step_iterator):
# Apply before hooks with current state (before step execution)
# Process through all steps
for idx, processor_step in enumerate(self.steps):
# Apply before hooks
for hook in self.before_step_hooks:
hook(idx, current_transition)
hook(idx, transition)
# Move to next state (after step execution)
current_transition = next_transition
# Execute step
transition = processor_step(transition)
# Apply after hooks with updated state
# Apply after hooks
for hook in self.after_step_hooks:
hook(idx, current_transition)
hook(idx, transition)
# Convert back to original format if needed
if called_with_batch or self.to_output is not _default_transition_to_batch:
return self.to_output(current_transition)
else:
return current_transition
# Always use to_output for consistent typing
return self.to_output(transition)
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Returns:
A tuple of (prepared_transition, called_with_batch_flag)
Raises:
ValueError: If the transition is not a valid EnvTransition format.
"""
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
return transition, called_with_batch
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
"""Yield the intermediate results after each processor step.
This is a low-level method that does NOT apply hooks. It simply executes each step
and yields the intermediate results. This allows users to debug the pipeline or
apply custom logic between steps if needed.
Note: This method always yields EnvTransition objects regardless of input format.
If you need the results in the original input format, you'll need to convert them
Note: This method always yields EnvTransition objects regardless of output format.
If you need the results in the output format, you'll need to convert them
using `to_output()`.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
data: A batch dictionary to process.
Yields:
The intermediate EnvTransition results after each step.
"""
transition, _ = self._prepare_transition(data)
# Always convert input through to_transition
transition = self.to_transition(data)
# Yield initial state
yield transition
@@ -525,8 +526,10 @@ class RobotProcessor(ModelHubMixin):
revision: str | None = None,
config_filename: str | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
**kwargs,
) -> RobotProcessor:
) -> RobotProcessor[TOutput]:
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
Args:
@@ -540,9 +543,14 @@ class RobotProcessor(ModelHubMixin):
(for registered steps). Values are dictionaries containing parameter overrides
that will be merged with the saved configuration. This is useful for providing
non-serializable objects like environment instances.
to_transition: Function to convert batch dict to EnvTransition dict.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition dict to the desired output format of type T.
Defaults to _default_transition_to_batch (returns batch dict).
Use identity function (lambda x: x) for EnvTransition output.
Returns:
A RobotProcessor instance loaded from the saved configuration.
A RobotProcessor[TOutput] instance loaded from the saved configuration.
Raises:
ImportError: If a processor step class cannot be loaded or imported.
@@ -756,19 +764,34 @@ class RobotProcessor(ModelHubMixin):
f"Make sure override keys match exact step class names or registry names."
)
return cls(steps, loaded_config.get("name", "RobotProcessor"))
return cls(
steps=steps,
name=loaded_config.get("name", "RobotProcessor"),
to_transition=to_transition or _default_batch_to_transition,
# Cast is necessary here: Same type-checker limitation as above.
# When to_output is None, we use the default which returns dict[str, Any].
# The cast ensures type consistency with the generic TOutput parameter.
to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
)
def __len__(self) -> int:
"""Return the number of steps in the processor."""
return len(self.steps)
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor[TOutput]:
"""Indexing helper exposing underlying steps.
* ``int`` returns the idx-th ProcessorStep.
* ``slice`` returns a new RobotProcessor with the sliced steps.
"""
if isinstance(idx, slice):
return RobotProcessor(self.steps[idx], self.name)
return RobotProcessor(
steps=self.steps[idx],
name=self.name,
to_transition=self.to_transition,
to_output=self.to_output,
before_step_hooks=self.before_step_hooks.copy(),
after_step_hooks=self.after_step_hooks.copy(),
)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
+51 -10
View File
@@ -102,7 +102,12 @@ def test_act_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
observation = {OBS_STATE: torch.randn(7)}
@@ -131,7 +136,12 @@ def test_act_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {OBS_STATE: torch.randn(7)}
@@ -160,7 +170,12 @@ def test_act_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -223,14 +238,21 @@ def test_act_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(7)}
@@ -249,7 +271,12 @@ def test_act_processor_device_placement_preservation():
# Test with CPU config
config.device = "cpu"
preprocessor, _ = make_act_pre_post_processors(config, stats)
preprocessor, _ = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Process CPU data
observation = {OBS_STATE: torch.randn(7)}
@@ -269,12 +296,21 @@ def test_act_processor_mixed_precision():
stats = create_default_stats()
# Modify the device processor to use float16
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessor):
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
@@ -294,7 +330,12 @@ def test_act_processor_batch_consistency():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
preprocessor, postprocessor = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test single sample (unbatched)
observation = {OBS_STATE: torch.randn(7)}
+11 -5
View File
@@ -245,7 +245,7 @@ def test_mixed_observation():
def test_integration_with_robot_processor():
"""Test ToBatchProcessor integration with RobotProcessor."""
to_batch_processor = ToBatchProcessor()
pipeline = RobotProcessor([to_batch_processor])
pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
# Create unbatched observation
observation = {
@@ -285,7 +285,9 @@ def test_serialization_methods():
def test_save_and_load_pretrained():
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
processor = ToBatchProcessor()
pipeline = RobotProcessor([processor], name="BatchPipeline")
pipeline = RobotProcessor(
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
@@ -296,7 +298,9 @@ def test_save_and_load_pretrained():
assert config_path.exists()
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
assert loaded_pipeline.name == "BatchPipeline"
assert len(loaded_pipeline) == 1
@@ -323,11 +327,13 @@ def test_registry_functionality():
def test_registry_based_save_load():
"""Test saving and loading using registry name."""
processor = ToBatchProcessor()
pipeline = RobotProcessor([processor])
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
# Verify the loaded processor works
observation = {
+50 -11
View File
@@ -97,7 +97,12 @@ def test_classifier_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(config, stats)
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
observation = {
@@ -123,7 +128,12 @@ def test_classifier_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(config, stats)
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -156,7 +166,12 @@ def test_classifier_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(config, stats)
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -230,14 +245,22 @@ def test_classifier_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {
@@ -260,13 +283,19 @@ def test_classifier_processor_mixed_precision():
config.device = "cuda"
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_classifier_processor(config, stats)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
modified_steps = []
for step in factory_preprocessor.steps:
if isinstance(step, DeviceProcessor):
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
# Create test data
observation = {
@@ -290,7 +319,12 @@ def test_classifier_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(config, stats)
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with batched data
batch_size = 16
@@ -315,7 +349,12 @@ def test_classifier_processor_postprocessor_identity():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_classifier_processor(config, stats)
preprocessor, postprocessor = make_classifier_processor(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data for postprocessor
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
+10 -1
View File
@@ -311,7 +311,12 @@ def test_integration_with_robot_processor():
device_processor = DeviceProcessor(device="cpu")
batch_processor = ToBatchProcessor()
processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
processor = RobotProcessor(
steps=[batch_processor, device_processor],
name="test_pipeline",
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Create test data
observation = {OBS_STATE: torch.randn(10)}
@@ -985,6 +990,8 @@ def test_policy_processor_integration():
DeviceProcessor(device="cuda"),
],
name="test_preprocessor",
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Create output processor (postprocessor) that moves to CPU
@@ -994,6 +1001,8 @@ def test_policy_processor_integration():
UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
],
name="test_postprocessor",
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Test data on CPU
+50 -11
View File
@@ -105,7 +105,12 @@ def test_diffusion_processor_with_images():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data with images
observation = {
@@ -131,7 +136,12 @@ def test_diffusion_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -164,7 +174,12 @@ def test_diffusion_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -238,14 +253,22 @@ def test_diffusion_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {
@@ -268,13 +291,19 @@ def test_diffusion_processor_mixed_precision():
config.device = "cuda"
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
modified_steps = []
for step in factory_preprocessor.steps:
if isinstance(step, DeviceProcessor):
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
else:
modified_steps.append(step)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
# Create test data
observation = {
@@ -298,7 +327,12 @@ def test_diffusion_processor_identity_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
image_value = torch.rand(3, 224, 224) * 255 # Large values
@@ -322,7 +356,12 @@ def test_diffusion_processor_batch_consistency():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
preprocessor, postprocessor = make_diffusion_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with different batch sizes
for batch_size in [1, 8, 32]:
+2 -2
View File
@@ -506,7 +506,7 @@ def test_get_config(full_stats):
def test_integration_with_robot_processor(normalizer_processor):
"""Test integration with RobotProcessor pipeline"""
robot_processor = RobotProcessor([normalizer_processor])
robot_processor = RobotProcessor([normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -1317,7 +1317,7 @@ def test_hotswap_stats_functional_test():
# Create original processor
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
original_processor = RobotProcessor(steps=[normalizer])
original_processor = RobotProcessor(steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x)
# Process with original stats
original_result = original_processor(transition)
+30 -5
View File
@@ -84,7 +84,12 @@ def test_make_pi0_processor_basic():
stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -183,7 +188,12 @@ def test_pi0_processor_cuda():
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -233,7 +243,12 @@ def test_pi0_processor_accelerate_scenario():
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -284,7 +299,12 @@ def test_pi0_processor_multi_gpu():
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -310,7 +330,12 @@ def test_pi0_processor_without_stats():
# Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
preprocessor, postprocessor = make_pi0_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
dataset_stats=None,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Should still create processors
assert preprocessor is not None
+32 -12
View File
@@ -176,7 +176,7 @@ class MockStepWithTensorState:
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotProcessor()
pipeline = RobotProcessor([], to_transition=lambda x: x, to_output=lambda x: x)
transition = create_transition()
result = pipeline(transition)
@@ -188,7 +188,7 @@ def test_empty_pipeline():
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
pipeline = RobotProcessor([step], to_transition=lambda x: x, to_output=lambda x: x)
transition = create_transition()
result = pipeline(transition)
@@ -205,7 +205,7 @@ def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
pipeline = RobotProcessor([step1, step2], to_transition=lambda x: x, to_output=lambda x: x)
transition = create_transition()
result = pipeline(transition)
@@ -557,7 +557,9 @@ def test_save_and_load_pretrained():
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotProcessor([step])
pipeline = RobotProcessor(
[step], to_transition=lambda x: x, to_output=lambda x: x
) # Identity for EnvTransition input/output
transition = create_transition(reward=2.0)
result = pipeline(transition)
@@ -878,7 +880,9 @@ def test_from_pretrained_with_overrides():
"registered_mock_step": {"device": "cuda", "value": 200},
}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
# Verify the pipeline was loaded correctly
assert len(loaded_pipeline) == 2
@@ -914,7 +918,9 @@ def test_from_pretrained_with_partial_overrides():
# The current implementation applies overrides to ALL steps with the same class name
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
@@ -971,7 +977,9 @@ def test_from_pretrained_registered_step_override():
# Override using registry name
overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that overrides were applied
transition = create_transition()
@@ -999,7 +1007,9 @@ def test_from_pretrained_mixed_registered_and_unregistered():
"registered_mock_step": {"value": 777},
}
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)
# Test both steps
transition = create_transition(reward=2.0)
@@ -1020,7 +1030,9 @@ def test_from_pretrained_no_overrides():
pipeline.save_pretrained(tmp_dir)
# Load without overrides
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
assert len(loaded_pipeline) == 1
@@ -1040,7 +1052,9 @@ def test_from_pretrained_empty_overrides():
pipeline.save_pretrained(tmp_dir)
# Load with empty overrides
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={})
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides={}, to_transition=lambda x: x, to_output=lambda x: x
)
assert len(loaded_pipeline) == 1
@@ -1455,6 +1469,8 @@ def test_override_with_nested_config():
loaded = RobotProcessor.from_pretrained(
tmp_dir,
overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Test that override worked
@@ -1553,7 +1569,10 @@ def test_override_with_callables():
# Load with callable override
loaded = RobotProcessor.from_pretrained(
tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
tmp_dir,
overrides={"callable_step": {"transform_fn": double_values}},
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Test it works
@@ -1857,7 +1876,8 @@ def test_save_load_with_custom_converter_functions():
# Should work with standard format (wouldn't work with custom converter)
result = loaded(batch)
assert "observation.image" in result # Standard format preserved
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
assert "observation.image" in result
class NonCompliantStep:
+6 -4
View File
@@ -188,7 +188,7 @@ def test_integration_with_robot_processor():
}
rename_processor = RenameProcessor(rename_map=rename_map)
pipeline = RobotProcessor([rename_processor])
pipeline = RobotProcessor([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
observation = {
"agent_pos": np.array([1.0, 2.0, 3.0]),
@@ -236,7 +236,9 @@ def test_save_and_load_pretrained():
assert len(state_files) == 0
# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
assert loaded_pipeline.name == "TestRenameProcessor"
assert len(loaded_pipeline) == 1
@@ -277,7 +279,7 @@ def test_registry_functionality():
def test_registry_based_save_load():
"""Test save/load using registry name instead of module path."""
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
pipeline = RobotProcessor([processor])
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save and load
@@ -318,7 +320,7 @@ def test_chained_rename_processors():
}
)
pipeline = RobotProcessor([processor1, processor2])
pipeline = RobotProcessor([processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x)
observation = {
"pos": np.array([1.0, 2.0]),
+73 -11
View File
@@ -78,7 +78,12 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -102,7 +107,12 @@ def test_sac_processor_normalization_modes():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
@@ -133,7 +143,12 @@ def test_sac_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {OBS_STATE: torch.randn(10)}
@@ -162,7 +177,12 @@ def test_sac_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -185,7 +205,12 @@ def test_sac_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -205,7 +230,22 @@ def test_sac_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Should still create processors
assert preprocessor is not None
@@ -225,14 +265,21 @@ def test_sac_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(10)}
@@ -252,7 +299,12 @@ def test_sac_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -277,7 +329,12 @@ def test_sac_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with batched data
batch_size = 32
@@ -298,7 +355,12 @@ def test_sac_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with empty observation
transition = create_transition(observation={}, action=torch.randn(5))
+30 -5
View File
@@ -89,7 +89,12 @@ def test_make_smolvla_processor_basic():
stats = create_default_stats()
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -188,7 +193,12 @@ def test_smolvla_processor_cuda():
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -238,7 +248,12 @@ def test_smolvla_processor_accelerate_scenario():
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -289,7 +304,12 @@ def test_smolvla_processor_multi_gpu():
return features
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -315,7 +335,12 @@ def test_smolvla_processor_without_stats():
# Mock the tokenizer processor
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
dataset_stats=None,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Should still create processors
assert preprocessor is not None
+73 -11
View File
@@ -81,7 +81,12 @@ def test_make_tdmpc_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -105,7 +110,12 @@ def test_tdmpc_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
observation = {
@@ -138,7 +148,12 @@ def test_tdmpc_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -171,7 +186,12 @@ def test_tdmpc_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -198,7 +218,12 @@ def test_tdmpc_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -222,7 +247,22 @@ def test_tdmpc_processor_without_stats():
"""Test TDMPC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Should still create processors
assert preprocessor is not None
@@ -245,14 +285,21 @@ def test_tdmpc_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {
@@ -276,7 +323,12 @@ def test_tdmpc_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -305,7 +357,12 @@ def test_tdmpc_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with batched data
batch_size = 64
@@ -330,7 +387,12 @@ def test_tdmpc_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with only state observation (no image)
observation = {OBS_STATE: torch.randn(12)}
+13 -6
View File
@@ -389,7 +389,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = RobotProcessor([tokenizer_processor])
robot_processor = RobotProcessor([tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -427,14 +427,16 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
)
robot_processor = RobotProcessor([original_processor])
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
robot_processor.save_pretrained(temp_dir)
# Load processor - tokenizer will be recreated from saved config
loaded_processor = RobotProcessor.from_pretrained(temp_dir)
loaded_processor = RobotProcessor.from_pretrained(
temp_dir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
transition = create_transition(
@@ -456,7 +458,7 @@ def test_save_and_load_pretrained_with_tokenizer_object():
original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
robot_processor = RobotProcessor([original_processor])
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
@@ -464,7 +466,10 @@ def test_save_and_load_pretrained_with_tokenizer_object():
# Load processor with tokenizer override (since tokenizer object wasn't saved)
loaded_processor = RobotProcessor.from_pretrained(
temp_dir, overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}
temp_dir,
overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}},
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Test that loaded processor works
@@ -952,7 +957,9 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
# Create pipeline with TokenizerProcessor then DeviceProcessor
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
device_processor = DeviceProcessor(device="cuda:0")
robot_processor = RobotProcessor([tokenizer_processor, device_processor])
robot_processor = RobotProcessor(
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
)
# Start with CPU tensors
transition = create_transition(
+73 -11
View File
@@ -81,7 +81,12 @@ def test_make_vqbet_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -105,7 +110,12 @@ def test_vqbet_processor_with_images():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data with images and states
observation = {
@@ -131,7 +141,12 @@ def test_vqbet_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -164,7 +179,12 @@ def test_vqbet_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -191,7 +211,12 @@ def test_vqbet_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -215,7 +240,22 @@ def test_vqbet_processor_without_stats():
"""Test VQBeT processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Should still create processors
assert preprocessor is not None
@@ -238,14 +278,21 @@ def test_vqbet_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {
@@ -269,7 +316,12 @@ def test_vqbet_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -298,7 +350,12 @@ def test_vqbet_processor_large_batch():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with large batch
batch_size = 128
@@ -323,7 +380,12 @@ def test_vqbet_processor_sequential_processing():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Process multiple samples sequentially
results = []