refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)

- Updated DataProcessorPipeline.__post_init__ to raise a TypeError unless every step is an instance of ProcessorStep (previously it only checked that each step was callable), enhancing type safety and clarity.
- Adjusted tests to patch TokenizerProcessorStep with a pass-through MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests.
- Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
This commit is contained in:
Adil Zouitine
2025-09-04 16:22:03 +02:00
committed by GitHub
parent fc43246942
commit 332ca4ccc5
4 changed files with 85 additions and 38 deletions
+2 -5
View File
@@ -731,11 +731,8 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]):
def __post_init__(self): def __post_init__(self):
for i, step in enumerate(self.steps): for i, step in enumerate(self.steps):
if not callable(step): if not isinstance(step, ProcessorStep):
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
raise TypeError(
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
)
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
""" """
+23 -5
View File
@@ -27,13 +27,31 @@ from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre
from lerobot.processor import ( from lerobot.processor import (
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
DeviceProcessorStep, DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep, NormalizerProcessorStep,
ProcessorStep,
RenameProcessorStep, RenameProcessorStep,
TransitionKey, TransitionKey,
UnnormalizerProcessorStep, UnnormalizerProcessorStep,
) )
class MockTokenizerProcessorStep(ProcessorStep):
    """Pass-through stand-in for the tokenizer processor step, for testing.

    Both the transition and the feature mapping are handed back unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments, mimicking the real TokenizerProcessorStep interface."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return ``transition`` unchanged."""
        return transition

    def transform_features(self, features):
        """Return ``features`` unchanged."""
        return features
def create_transition(observation=None, action=None, **kwargs): def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary.""" """Helper function to create a transition dictionary."""
transition = {} transition = {}
@@ -83,7 +101,7 @@ def test_make_pi0_processor_basic():
config = create_default_config() config = create_default_config()
stats = create_default_stats() stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"): with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors( preprocessor, postprocessor = make_pi0_pre_post_processors(
config, config,
stats, stats,
@@ -165,7 +183,7 @@ def test_pi0_processor_cuda():
stats = create_default_stats() stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through # Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep: class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario():
stats = create_default_stats() stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through # Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep: class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu():
stats = create_default_stats() stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through # Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep: class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -329,7 +347,7 @@ def test_pi0_processor_without_stats():
config = create_default_config() config = create_default_config()
# Mock the tokenizer processor # Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"): with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors( preprocessor, postprocessor = make_pi0_pre_post_processors(
config, config,
dataset_stats=None, dataset_stats=None,
+33 -23
View File
@@ -27,7 +27,13 @@ import torch.nn as nn
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.processor import DataProcessorPipeline, EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.processor import (
DataProcessorPipeline,
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from tests.conftest import assert_contract_is_typed from tests.conftest import assert_contract_is_typed
@@ -47,7 +53,7 @@ def create_transition(
@dataclass @dataclass
class MockStep: class MockStep(ProcessorStep):
"""Mock pipeline step for testing - demonstrates best practices. """Mock pipeline step for testing - demonstrates best practices.
This example shows the proper separation: This example shows the proper separation:
@@ -96,7 +102,7 @@ class MockStep:
@dataclass @dataclass
class MockStepWithoutOptionalMethods: class MockStepWithoutOptionalMethods(ProcessorStep):
"""Mock step that only implements the required __call__ method.""" """Mock step that only implements the required __call__ method."""
multiplier: float = 2.0 multiplier: float = 2.0
@@ -118,7 +124,7 @@ class MockStepWithoutOptionalMethods:
@dataclass @dataclass
class MockStepWithTensorState: class MockStepWithTensorState(ProcessorStep):
"""Mock step demonstrating mixed JSON attributes and tensor state.""" """Mock step demonstrating mixed JSON attributes and tensor state."""
name: str = "tensor_step" name: str = "tensor_step"
@@ -613,7 +619,7 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean) assert torch.allclose(loaded_step.running_mean, step.running_mean)
class MockModuleStep(nn.Module): class MockModuleStep(ProcessorStep, nn.Module):
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" """Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
def __init__(self, input_dim: int = 10, hidden_dim: int = 5): def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
@@ -653,12 +659,12 @@ class MockModuleStep(nn.Module):
def state_dict(self) -> dict[str, torch.Tensor]: def state_dict(self) -> dict[str, torch.Tensor]:
"""Override to return all module parameters and buffers.""" """Override to return all module parameters and buffers."""
# Get the module's state dict (includes all parameters and buffers) # Get the module's state dict (includes all parameters and buffers)
return super().state_dict() return nn.Module.state_dict(self)
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Override to load all module parameters and buffers.""" """Override to load all module parameters and buffers."""
# Use the module's load_state_dict # Use the module's load_state_dict
super().load_state_dict(state) nn.Module.load_state_dict(self, state)
def reset(self) -> None: def reset(self) -> None:
self.running_mean.zero_() self.running_mean.zero_()
@@ -669,7 +675,7 @@ class MockModuleStep(nn.Module):
return features return features
class MockNonModuleStepWithState: class MockNonModuleStepWithState(ProcessorStep):
"""Mock step that explicitly does NOT inherit from nn.Module but has tensor state. """Mock step that explicitly does NOT inherit from nn.Module but has tensor state.
This tests the state_dict/load_state_dict path for regular classes. This tests the state_dict/load_state_dict path for regular classes.
@@ -753,7 +759,7 @@ class MockNonModuleStepWithState:
# Tests for overrides functionality # Tests for overrides functionality
@dataclass @dataclass
class MockStepWithNonSerializableParam: class MockStepWithNonSerializableParam(ProcessorStep):
"""Mock step that requires a non-serializable parameter.""" """Mock step that requires a non-serializable parameter."""
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
@@ -808,7 +814,7 @@ class MockStepWithNonSerializableParam:
@ProcessorStepRegistry.register("registered_mock_step") @ProcessorStepRegistry.register("registered_mock_step")
@dataclass @dataclass
class RegisteredMockStep: class RegisteredMockStep(ProcessorStep):
"""Mock step registered in the registry.""" """Mock step registered in the registry."""
value: int = 42 value: int = 42
@@ -1381,7 +1387,7 @@ def test_state_file_naming_with_registry():
# Register a test step # Register a test step
@ProcessorStepRegistry.register("test_stateful_step") @ProcessorStepRegistry.register("test_stateful_step")
@dataclass @dataclass
class TestStatefulStep: class TestStatefulStep(ProcessorStep):
value: int = 0 value: int = 0
def __init__(self, value: int = 0): def __init__(self, value: int = 0):
@@ -1436,7 +1442,7 @@ def test_override_with_nested_config():
@ProcessorStepRegistry.register("complex_config_step") @ProcessorStepRegistry.register("complex_config_step")
@dataclass @dataclass
class ComplexConfigStep: class ComplexConfigStep(ProcessorStep):
name: str = "complex" name: str = "complex"
simple_param: int = 42 simple_param: int = 42
nested_config: dict = None nested_config: dict = None
@@ -1532,7 +1538,7 @@ def test_override_with_callables():
@ProcessorStepRegistry.register("callable_step") @ProcessorStepRegistry.register("callable_step")
@dataclass @dataclass
class CallableStep: class CallableStep(ProcessorStep):
name: str = "callable_step" name: str = "callable_step"
transform_fn: Any = None transform_fn: Any = None
@@ -1667,7 +1673,7 @@ def test_override_with_device_strings():
@ProcessorStepRegistry.register("device_aware_step") @ProcessorStepRegistry.register("device_aware_step")
@dataclass @dataclass
class DeviceAwareStep: class DeviceAwareStep(ProcessorStep):
device: str = "cpu" device: str = "cpu"
def __init__(self, device: str = "cpu"): def __init__(self, device: str = "cpu"):
@@ -1806,13 +1812,17 @@ class NonCallableStep:
return features return features
def test_construction_rejects_step_without_call(): def test_construction_rejects_step_without_processorstep():
with pytest.raises(TypeError, match=r"must define __call__"): """Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
DataProcessorPipeline([NonCallableStep()]) DataProcessorPipeline([NonCallableStep()])
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
DataProcessorPipeline([NonCompliantStep()])
@dataclass @dataclass
class FeatureContractAddStep: class FeatureContractAddStep(ProcessorStep):
"""Adds a PolicyFeature""" """Adds a PolicyFeature"""
key: str = "a" key: str = "a"
@@ -1827,7 +1837,7 @@ class FeatureContractAddStep:
@dataclass @dataclass
class FeatureContractMutateStep: class FeatureContractMutateStep(ProcessorStep):
"""Mutates a PolicyFeature""" """Mutates a PolicyFeature"""
key: str = "a" key: str = "a"
@@ -1842,7 +1852,7 @@ class FeatureContractMutateStep:
@dataclass @dataclass
class FeatureContractBadReturnStep: class FeatureContractBadReturnStep(ProcessorStep):
"""Returns a non-dict""" """Returns a non-dict"""
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -1853,7 +1863,7 @@ class FeatureContractBadReturnStep:
@dataclass @dataclass
class FeatureContractRemoveStep: class FeatureContractRemoveStep(ProcessorStep):
"""Removes a PolicyFeature""" """Removes a PolicyFeature"""
key: str key: str
@@ -1906,7 +1916,7 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
def test_features_execution_order_tracking(): def test_features_execution_order_tracking():
class Track: class Track(ProcessorStep):
def __init__(self, label): def __init__(self, label):
self.label = label self.label = label
@@ -1945,7 +1955,7 @@ def test_features_remove_from_initial(policy_feature_factory):
@dataclass @dataclass
class AddActionEEAndJointFeatures: class AddActionEEAndJointFeatures(ProcessorStep):
"""Adds both EE and JOINT action features.""" """Adds both EE and JOINT action features."""
def __call__(self, tr): def __call__(self, tr):
@@ -1962,7 +1972,7 @@ class AddActionEEAndJointFeatures:
@dataclass @dataclass
class AddObservationStateFeatures: class AddObservationStateFeatures(ProcessorStep):
"""Adds state features (and optionally an image spec to test precedence).""" """Adds state features (and optionally an image spec to test precedence)."""
add_front_image: bool = False add_front_image: bool = False
+27 -5
View File
@@ -30,13 +30,31 @@ from lerobot.policies.smolvla.processor_smolvla import (
from lerobot.processor import ( from lerobot.processor import (
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
DeviceProcessorStep, DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep, NormalizerProcessorStep,
ProcessorStep,
RenameProcessorStep, RenameProcessorStep,
TransitionKey, TransitionKey,
UnnormalizerProcessorStep, UnnormalizerProcessorStep,
) )
class MockTokenizerProcessorStep(ProcessorStep):
    """Pass-through stand-in for the tokenizer processor step, for testing.

    Both the transition and the feature mapping are handed back unchanged.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments, mimicking the real TokenizerProcessorStep interface."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return ``transition`` unchanged."""
        return transition

    def transform_features(self, features):
        """Return ``features`` unchanged."""
        return features
def create_transition(observation=None, action=None, **kwargs): def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary.""" """Helper function to create a transition dictionary."""
transition = {} transition = {}
@@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic():
config = create_default_config() config = create_default_config()
stats = create_default_stats() stats = create_default_stats()
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"): with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors( preprocessor, postprocessor = make_smolvla_pre_post_processors(
config, config,
stats, stats,
@@ -170,7 +190,7 @@ def test_smolvla_processor_cuda():
stats = create_default_stats() stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through # Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep: class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -227,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario():
stats = create_default_stats() stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through # Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep: class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -285,7 +305,7 @@ def test_smolvla_processor_multi_gpu():
stats = create_default_stats() stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through # Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep: class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
@@ -340,7 +360,9 @@ def test_smolvla_processor_without_stats():
config = create_default_config() config = create_default_config()
# Mock the tokenizer processor # Mock the tokenizer processor
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"): with patch(
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
):
preprocessor, postprocessor = make_smolvla_pre_post_processors( preprocessor, postprocessor = make_smolvla_pre_post_processors(
config, config,
dataset_stats=None, dataset_stats=None,