refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)

- Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity.
- Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests.
- Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
This commit is contained in:
Adil Zouitine
2025-09-04 16:22:03 +02:00
committed by GitHub
parent fc43246942
commit 332ca4ccc5
4 changed files with 85 additions and 38 deletions
+2 -5
View File
@@ -731,11 +731,8 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]):
def __post_init__(self):
for i, step in enumerate(self.steps):
if not callable(step):
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
raise TypeError(
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
if not isinstance(step, ProcessorStep):
raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
+23 -5
View File
@@ -27,13 +27,31 @@ from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
ProcessorStep,
RenameProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
class MockTokenizerProcessorStep(ProcessorStep):
"""Mock tokenizer processor step for testing."""
def __init__(self, *args, **kwargs):
# Accept any arguments to mimic the real TokenizerProcessorStep interface
pass
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Pass through transition unchanged
return transition
def transform_features(self, features):
# Pass through features unchanged
return features
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
@@ -83,7 +101,7 @@ def test_make_pi0_processor_basic():
config = create_default_config()
stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
@@ -165,7 +183,7 @@ def test_pi0_processor_cuda():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -329,7 +347,7 @@ def test_pi0_processor_without_stats():
config = create_default_config()
# Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"):
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
dataset_stats=None,
+33 -23
View File
@@ -27,7 +27,13 @@ import torch.nn as nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.processor import DataProcessorPipeline, EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor import (
DataProcessorPipeline,
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from tests.conftest import assert_contract_is_typed
@@ -47,7 +53,7 @@ def create_transition(
@dataclass
class MockStep:
class MockStep(ProcessorStep):
"""Mock pipeline step for testing - demonstrates best practices.
This example shows the proper separation:
@@ -96,7 +102,7 @@ class MockStep:
@dataclass
class MockStepWithoutOptionalMethods:
class MockStepWithoutOptionalMethods(ProcessorStep):
"""Mock step that only implements the required __call__ method."""
multiplier: float = 2.0
@@ -118,7 +124,7 @@ class MockStepWithoutOptionalMethods:
@dataclass
class MockStepWithTensorState:
class MockStepWithTensorState(ProcessorStep):
"""Mock step demonstrating mixed JSON attributes and tensor state."""
name: str = "tensor_step"
@@ -613,7 +619,7 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean)
class MockModuleStep(nn.Module):
class MockModuleStep(ProcessorStep, nn.Module):
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
@@ -653,12 +659,12 @@ class MockModuleStep(nn.Module):
def state_dict(self) -> dict[str, torch.Tensor]:
"""Override to return all module parameters and buffers."""
# Get the module's state dict (includes all parameters and buffers)
return super().state_dict()
return nn.Module.state_dict(self)
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Override to load all module parameters and buffers."""
# Use the module's load_state_dict
super().load_state_dict(state)
nn.Module.load_state_dict(self, state)
def reset(self) -> None:
self.running_mean.zero_()
@@ -669,7 +675,7 @@ class MockModuleStep(nn.Module):
return features
class MockNonModuleStepWithState:
class MockNonModuleStepWithState(ProcessorStep):
"""Mock step that explicitly does NOT inherit from nn.Module but has tensor state.
This tests the state_dict/load_state_dict path for regular classes.
@@ -753,7 +759,7 @@ class MockNonModuleStepWithState:
# Tests for overrides functionality
@dataclass
class MockStepWithNonSerializableParam:
class MockStepWithNonSerializableParam(ProcessorStep):
"""Mock step that requires a non-serializable parameter."""
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
@@ -808,7 +814,7 @@ class MockStepWithNonSerializableParam:
@ProcessorStepRegistry.register("registered_mock_step")
@dataclass
class RegisteredMockStep:
class RegisteredMockStep(ProcessorStep):
"""Mock step registered in the registry."""
value: int = 42
@@ -1381,7 +1387,7 @@ def test_state_file_naming_with_registry():
# Register a test step
@ProcessorStepRegistry.register("test_stateful_step")
@dataclass
class TestStatefulStep:
class TestStatefulStep(ProcessorStep):
value: int = 0
def __init__(self, value: int = 0):
@@ -1436,7 +1442,7 @@ def test_override_with_nested_config():
@ProcessorStepRegistry.register("complex_config_step")
@dataclass
class ComplexConfigStep:
class ComplexConfigStep(ProcessorStep):
name: str = "complex"
simple_param: int = 42
nested_config: dict = None
@@ -1532,7 +1538,7 @@ def test_override_with_callables():
@ProcessorStepRegistry.register("callable_step")
@dataclass
class CallableStep:
class CallableStep(ProcessorStep):
name: str = "callable_step"
transform_fn: Any = None
@@ -1667,7 +1673,7 @@ def test_override_with_device_strings():
@ProcessorStepRegistry.register("device_aware_step")
@dataclass
class DeviceAwareStep:
class DeviceAwareStep(ProcessorStep):
device: str = "cpu"
def __init__(self, device: str = "cpu"):
@@ -1806,13 +1812,17 @@ class NonCallableStep:
return features
def test_construction_rejects_step_without_call():
with pytest.raises(TypeError, match=r"must define __call__"):
def test_construction_rejects_step_without_processorstep():
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
DataProcessorPipeline([NonCallableStep()])
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
DataProcessorPipeline([NonCompliantStep()])
@dataclass
class FeatureContractAddStep:
class FeatureContractAddStep(ProcessorStep):
"""Adds a PolicyFeature"""
key: str = "a"
@@ -1827,7 +1837,7 @@ class FeatureContractAddStep:
@dataclass
class FeatureContractMutateStep:
class FeatureContractMutateStep(ProcessorStep):
"""Mutates a PolicyFeature"""
key: str = "a"
@@ -1842,7 +1852,7 @@ class FeatureContractMutateStep:
@dataclass
class FeatureContractBadReturnStep:
class FeatureContractBadReturnStep(ProcessorStep):
"""Returns a non-dict"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -1853,7 +1863,7 @@ class FeatureContractBadReturnStep:
@dataclass
class FeatureContractRemoveStep:
class FeatureContractRemoveStep(ProcessorStep):
"""Removes a PolicyFeature"""
key: str
@@ -1906,7 +1916,7 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
def test_features_execution_order_tracking():
class Track:
class Track(ProcessorStep):
def __init__(self, label):
self.label = label
@@ -1945,7 +1955,7 @@ def test_features_remove_from_initial(policy_feature_factory):
@dataclass
class AddActionEEAndJointFeatures:
class AddActionEEAndJointFeatures(ProcessorStep):
"""Adds both EE and JOINT action features."""
def __call__(self, tr):
@@ -1962,7 +1972,7 @@ class AddActionEEAndJointFeatures:
@dataclass
class AddObservationStateFeatures:
class AddObservationStateFeatures(ProcessorStep):
"""Adds state features (and optionally an image spec to test precedence)."""
add_front_image: bool = False
+27 -5
View File
@@ -30,13 +30,31 @@ from lerobot.policies.smolvla.processor_smolvla import (
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
ProcessorStep,
RenameProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
class MockTokenizerProcessorStep(ProcessorStep):
"""Mock tokenizer processor step for testing."""
def __init__(self, *args, **kwargs):
# Accept any arguments to mimic the real TokenizerProcessorStep interface
pass
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Pass through transition unchanged
return transition
def transform_features(self, features):
# Pass through features unchanged
return features
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
@@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic():
config = create_default_config()
stats = create_default_stats()
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
@@ -170,7 +190,7 @@ def test_smolvla_processor_cuda():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -227,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -285,7 +305,7 @@ def test_smolvla_processor_multi_gpu():
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep:
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
@@ -340,7 +360,9 @@ def test_smolvla_processor_without_stats():
config = create_default_config()
# Mock the tokenizer processor
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"):
with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
dataset_stats=None,