From 332ca4ccc53bfacf294dc00e8f6932d3976a7701 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 4 Sep 2025 16:22:03 +0200 Subject: [PATCH] refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862) - Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity. - Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests. - Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability. --- src/lerobot/processor/pipeline.py | 7 +-- tests/processor/test_pi0_processor.py | 28 ++++++++++-- tests/processor/test_pipeline.py | 56 +++++++++++++---------- tests/processor/test_smolvla_processor.py | 32 +++++++++++-- 4 files changed, 85 insertions(+), 38 deletions(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index c56197dff..5f0a9be46 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -731,11 +731,8 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): def __post_init__(self): for i, step in enumerate(self.steps): - if not callable(step): - # TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated - raise TypeError( - f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" - ) + if not isinstance(step, ProcessorStep): + raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep") def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """ diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py index 589bf209f..82db59921 100644 --- a/tests/processor/test_pi0_processor.py +++ b/tests/processor/test_pi0_processor.py @@ -27,13 +27,31 @@ from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, + EnvTransition, NormalizerProcessorStep, + ProcessorStep, RenameProcessorStep, TransitionKey, UnnormalizerProcessorStep, ) +class MockTokenizerProcessorStep(ProcessorStep): + """Mock tokenizer processor step for testing.""" + + def __init__(self, *args, **kwargs): + # Accept any arguments to mimic the real TokenizerProcessorStep interface + pass + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Pass through transition unchanged + return transition + + def transform_features(self, features): + # Pass through features unchanged + return features + + def create_transition(observation=None, action=None, **kwargs): """Helper function to create a transition dictionary.""" transition = {} @@ -83,7 +101,7 @@ def test_make_pi0_processor_basic(): config = create_default_config() stats = create_default_stats() - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"): + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): preprocessor, postprocessor = make_pi0_pre_post_processors( config, stats, @@ -165,7 +183,7 @@ def test_pi0_processor_cuda(): stats = create_default_stats() # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep: + class MockTokenizerProcessorStep(ProcessorStep): def __init__(self, *args, **kwargs): pass @@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario(): stats = create_default_stats() # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep: + class MockTokenizerProcessorStep(ProcessorStep): def __init__(self, *args, **kwargs): pass @@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu(): stats = create_default_stats() # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep: + class MockTokenizerProcessorStep(ProcessorStep): def __init__(self, *args, **kwargs): pass @@ -329,7 +347,7 @@ def test_pi0_processor_without_stats(): config = create_default_config() # Mock the tokenizer processor - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"): + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): preprocessor, postprocessor = make_pi0_pre_post_processors( config, dataset_stats=None, diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index d7ab7d6a0..f4a4c6e44 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -27,7 +27,13 @@ import torch.nn as nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features -from lerobot.processor import DataProcessorPipeline, EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor import ( + DataProcessorPipeline, + EnvTransition, + ProcessorStep, + ProcessorStepRegistry, + TransitionKey, +) from tests.conftest import assert_contract_is_typed @@ -47,7 +53,7 @@ def create_transition( @dataclass -class MockStep: +class MockStep(ProcessorStep): """Mock pipeline step for testing - demonstrates best practices. This example shows the proper separation: @@ -96,7 +102,7 @@ class MockStep: @dataclass -class MockStepWithoutOptionalMethods: +class MockStepWithoutOptionalMethods(ProcessorStep): """Mock step that only implements the required __call__ method.""" multiplier: float = 2.0 @@ -118,7 +124,7 @@ class MockStepWithoutOptionalMethods: @dataclass -class MockStepWithTensorState: +class MockStepWithTensorState(ProcessorStep): """Mock step demonstrating mixed JSON attributes and tensor state.""" name: str = "tensor_step" @@ -613,7 +619,7 @@ def test_mixed_json_and_tensor_state(): assert torch.allclose(loaded_step.running_mean, step.running_mean) -class MockModuleStep(nn.Module): +class MockModuleStep(ProcessorStep, nn.Module): """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" def __init__(self, input_dim: int = 10, hidden_dim: int = 5): @@ -653,12 +659,12 @@ class MockModuleStep(nn.Module): def state_dict(self) -> dict[str, torch.Tensor]: """Override to return all module parameters and buffers.""" # Get the module's state dict (includes all parameters and buffers) - return super().state_dict() + return nn.Module.state_dict(self) def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: """Override to load all module parameters and buffers.""" # Use the module's load_state_dict - super().load_state_dict(state) + nn.Module.load_state_dict(self, state) def reset(self) -> None: self.running_mean.zero_() @@ -669,7 +675,7 @@ class MockModuleStep(nn.Module): return features -class MockNonModuleStepWithState: +class MockNonModuleStepWithState(ProcessorStep): """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. This tests the state_dict/load_state_dict path for regular classes. @@ -753,7 +759,7 @@ class MockNonModuleStepWithState: # Tests for overrides functionality @dataclass -class MockStepWithNonSerializableParam: +class MockStepWithNonSerializableParam(ProcessorStep): """Mock step that requires a non-serializable parameter.""" def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): @@ -808,7 +814,7 @@ class MockStepWithNonSerializableParam: @ProcessorStepRegistry.register("registered_mock_step") @dataclass -class RegisteredMockStep: +class RegisteredMockStep(ProcessorStep): """Mock step registered in the registry.""" value: int = 42 @@ -1381,7 +1387,7 @@ def test_state_file_naming_with_registry(): # Register a test step @ProcessorStepRegistry.register("test_stateful_step") @dataclass - class TestStatefulStep: + class TestStatefulStep(ProcessorStep): value: int = 0 def __init__(self, value: int = 0): @@ -1436,7 +1442,7 @@ def test_override_with_nested_config(): @ProcessorStepRegistry.register("complex_config_step") @dataclass - class ComplexConfigStep: + class ComplexConfigStep(ProcessorStep): name: str = "complex" simple_param: int = 42 nested_config: dict = None @@ -1532,7 +1538,7 @@ def test_override_with_callables(): @ProcessorStepRegistry.register("callable_step") @dataclass - class CallableStep: + class CallableStep(ProcessorStep): name: str = "callable_step" transform_fn: Any = None @@ -1667,7 +1673,7 @@ def test_override_with_device_strings(): @ProcessorStepRegistry.register("device_aware_step") @dataclass - class DeviceAwareStep: + class DeviceAwareStep(ProcessorStep): device: str = "cpu" def __init__(self, device: str = "cpu"): @@ -1806,13 +1812,17 @@ class NonCallableStep: return features -def test_construction_rejects_step_without_call(): - with pytest.raises(TypeError, match=r"must define __call__"): +def test_construction_rejects_step_without_processorstep(): + """Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep.""" + with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"): DataProcessorPipeline([NonCallableStep()]) + with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"): + DataProcessorPipeline([NonCompliantStep()]) + @dataclass -class FeatureContractAddStep: +class FeatureContractAddStep(ProcessorStep): """Adds a PolicyFeature""" key: str = "a" @@ -1827,7 +1837,7 @@ class FeatureContractAddStep: @dataclass -class FeatureContractMutateStep: +class FeatureContractMutateStep(ProcessorStep): """Mutates a PolicyFeature""" key: str = "a" @@ -1842,7 +1852,7 @@ class FeatureContractMutateStep: @dataclass -class FeatureContractBadReturnStep: +class FeatureContractBadReturnStep(ProcessorStep): """Returns a non-dict""" def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -1853,7 +1863,7 @@ class FeatureContractBadReturnStep: @dataclass -class FeatureContractRemoveStep: +class FeatureContractRemoveStep(ProcessorStep): """Removes a PolicyFeature""" key: str @@ -1906,7 +1916,7 @@ def test_features_respects_initial_without_mutation(policy_feature_factory): def test_features_execution_order_tracking(): - class Track: + class Track(ProcessorStep): def __init__(self, label): self.label = label @@ -1945,7 +1955,7 @@ def test_features_remove_from_initial(policy_feature_factory): @dataclass -class AddActionEEAndJointFeatures: +class AddActionEEAndJointFeatures(ProcessorStep): """Adds both EE and JOINT action features.""" def __call__(self, tr): @@ -1962,7 +1972,7 @@ class AddActionEEAndJointFeatures: @dataclass -class AddObservationStateFeatures: +class AddObservationStateFeatures(ProcessorStep): """Adds state features (and optionally an image spec to test precedence).""" add_front_image: bool = False diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py index dc1b1a83f..05bccec80 100644 --- a/tests/processor/test_smolvla_processor.py +++ b/tests/processor/test_smolvla_processor.py @@ -30,13 +30,31 @@ from lerobot.policies.smolvla.processor_smolvla import ( from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, + EnvTransition, NormalizerProcessorStep, + ProcessorStep, RenameProcessorStep, TransitionKey, UnnormalizerProcessorStep, ) +class MockTokenizerProcessorStep(ProcessorStep): + """Mock tokenizer processor step for testing.""" + + def __init__(self, *args, **kwargs): + # Accept any arguments to mimic the real TokenizerProcessorStep interface + pass + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Pass through transition unchanged + return transition + + def transform_features(self, features): + # Pass through features unchanged + return features + + def create_transition(observation=None, action=None, **kwargs): """Helper function to create a transition dictionary.""" transition = {} @@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic(): config = create_default_config() stats = create_default_stats() - with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"): + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, stats, @@ -170,7 +190,7 @@ def test_smolvla_processor_cuda(): stats = create_default_stats() # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep: + class MockTokenizerProcessorStep(ProcessorStep): def __init__(self, *args, **kwargs): pass @@ -227,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario(): stats = create_default_stats() # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep: + class MockTokenizerProcessorStep(ProcessorStep): def __init__(self, *args, **kwargs): pass @@ -285,7 +305,7 @@ def test_smolvla_processor_multi_gpu(): stats = create_default_stats() # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep: + class MockTokenizerProcessorStep(ProcessorStep): def __init__(self, *args, **kwargs): pass @@ -340,7 +360,9 @@ def test_smolvla_processor_without_stats(): config = create_default_config() # Mock the tokenizer processor - with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"): + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, dataset_stats=None,