mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)
- Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity. - Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests. - Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
This commit is contained in:
@@ -731,11 +731,8 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]):
|
||||
|
||||
def __post_init__(self):
|
||||
for i, step in enumerate(self.steps):
|
||||
if not callable(step):
|
||||
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
|
||||
raise TypeError(
|
||||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||||
)
|
||||
if not isinstance(step, ProcessorStep):
|
||||
raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
|
||||
|
||||
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
|
||||
@@ -27,13 +27,31 @@ from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
ProcessorStep,
|
||||
RenameProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
"""Mock tokenizer processor step for testing."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||
pass
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Pass through transition unchanged
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
# Pass through features unchanged
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
@@ -83,7 +101,7 @@ def test_make_pi0_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -165,7 +183,7 @@ def test_pi0_processor_cuda():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -329,7 +347,7 @@ def test_pi0_processor_without_stats():
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
|
||||
@@ -27,7 +27,13 @@ import torch.nn as nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline, EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -47,7 +53,7 @@ def create_transition(
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStep:
|
||||
class MockStep(ProcessorStep):
|
||||
"""Mock pipeline step for testing - demonstrates best practices.
|
||||
|
||||
This example shows the proper separation:
|
||||
@@ -96,7 +102,7 @@ class MockStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithoutOptionalMethods:
|
||||
class MockStepWithoutOptionalMethods(ProcessorStep):
|
||||
"""Mock step that only implements the required __call__ method."""
|
||||
|
||||
multiplier: float = 2.0
|
||||
@@ -118,7 +124,7 @@ class MockStepWithoutOptionalMethods:
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockStepWithTensorState:
|
||||
class MockStepWithTensorState(ProcessorStep):
|
||||
"""Mock step demonstrating mixed JSON attributes and tensor state."""
|
||||
|
||||
name: str = "tensor_step"
|
||||
@@ -613,7 +619,7 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
class MockModuleStep(nn.Module):
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
|
||||
@@ -653,12 +659,12 @@ class MockModuleStep(nn.Module):
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Override to return all module parameters and buffers."""
|
||||
# Get the module's state dict (includes all parameters and buffers)
|
||||
return super().state_dict()
|
||||
return nn.Module.state_dict(self)
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Override to load all module parameters and buffers."""
|
||||
# Use the module's load_state_dict
|
||||
super().load_state_dict(state)
|
||||
nn.Module.load_state_dict(self, state)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.running_mean.zero_()
|
||||
@@ -669,7 +675,7 @@ class MockModuleStep(nn.Module):
|
||||
return features
|
||||
|
||||
|
||||
class MockNonModuleStepWithState:
|
||||
class MockNonModuleStepWithState(ProcessorStep):
|
||||
"""Mock step that explicitly does NOT inherit from nn.Module but has tensor state.
|
||||
|
||||
This tests the state_dict/load_state_dict path for regular classes.
|
||||
@@ -753,7 +759,7 @@ class MockNonModuleStepWithState:
|
||||
|
||||
# Tests for overrides functionality
|
||||
@dataclass
|
||||
class MockStepWithNonSerializableParam:
|
||||
class MockStepWithNonSerializableParam(ProcessorStep):
|
||||
"""Mock step that requires a non-serializable parameter."""
|
||||
|
||||
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
|
||||
@@ -808,7 +814,7 @@ class MockStepWithNonSerializableParam:
|
||||
|
||||
@ProcessorStepRegistry.register("registered_mock_step")
|
||||
@dataclass
|
||||
class RegisteredMockStep:
|
||||
class RegisteredMockStep(ProcessorStep):
|
||||
"""Mock step registered in the registry."""
|
||||
|
||||
value: int = 42
|
||||
@@ -1381,7 +1387,7 @@ def test_state_file_naming_with_registry():
|
||||
# Register a test step
|
||||
@ProcessorStepRegistry.register("test_stateful_step")
|
||||
@dataclass
|
||||
class TestStatefulStep:
|
||||
class TestStatefulStep(ProcessorStep):
|
||||
value: int = 0
|
||||
|
||||
def __init__(self, value: int = 0):
|
||||
@@ -1436,7 +1442,7 @@ def test_override_with_nested_config():
|
||||
|
||||
@ProcessorStepRegistry.register("complex_config_step")
|
||||
@dataclass
|
||||
class ComplexConfigStep:
|
||||
class ComplexConfigStep(ProcessorStep):
|
||||
name: str = "complex"
|
||||
simple_param: int = 42
|
||||
nested_config: dict = None
|
||||
@@ -1532,7 +1538,7 @@ def test_override_with_callables():
|
||||
|
||||
@ProcessorStepRegistry.register("callable_step")
|
||||
@dataclass
|
||||
class CallableStep:
|
||||
class CallableStep(ProcessorStep):
|
||||
name: str = "callable_step"
|
||||
transform_fn: Any = None
|
||||
|
||||
@@ -1667,7 +1673,7 @@ def test_override_with_device_strings():
|
||||
|
||||
@ProcessorStepRegistry.register("device_aware_step")
|
||||
@dataclass
|
||||
class DeviceAwareStep:
|
||||
class DeviceAwareStep(ProcessorStep):
|
||||
device: str = "cpu"
|
||||
|
||||
def __init__(self, device: str = "cpu"):
|
||||
@@ -1806,13 +1812,17 @@ class NonCallableStep:
|
||||
return features
|
||||
|
||||
|
||||
def test_construction_rejects_step_without_call():
|
||||
with pytest.raises(TypeError, match=r"must define __call__"):
|
||||
def test_construction_rejects_step_without_processorstep():
|
||||
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||
DataProcessorPipeline([NonCallableStep()])
|
||||
|
||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||
DataProcessorPipeline([NonCompliantStep()])
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractAddStep:
|
||||
class FeatureContractAddStep(ProcessorStep):
|
||||
"""Adds a PolicyFeature"""
|
||||
|
||||
key: str = "a"
|
||||
@@ -1827,7 +1837,7 @@ class FeatureContractAddStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractMutateStep:
|
||||
class FeatureContractMutateStep(ProcessorStep):
|
||||
"""Mutates a PolicyFeature"""
|
||||
|
||||
key: str = "a"
|
||||
@@ -1842,7 +1852,7 @@ class FeatureContractMutateStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractBadReturnStep:
|
||||
class FeatureContractBadReturnStep(ProcessorStep):
|
||||
"""Returns a non-dict"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -1853,7 +1863,7 @@ class FeatureContractBadReturnStep:
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureContractRemoveStep:
|
||||
class FeatureContractRemoveStep(ProcessorStep):
|
||||
"""Removes a PolicyFeature"""
|
||||
|
||||
key: str
|
||||
@@ -1906,7 +1916,7 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
|
||||
|
||||
|
||||
def test_features_execution_order_tracking():
|
||||
class Track:
|
||||
class Track(ProcessorStep):
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
|
||||
@@ -1945,7 +1955,7 @@ def test_features_remove_from_initial(policy_feature_factory):
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddActionEEAndJointFeatures:
|
||||
class AddActionEEAndJointFeatures(ProcessorStep):
|
||||
"""Adds both EE and JOINT action features."""
|
||||
|
||||
def __call__(self, tr):
|
||||
@@ -1962,7 +1972,7 @@ class AddActionEEAndJointFeatures:
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddObservationStateFeatures:
|
||||
class AddObservationStateFeatures(ProcessorStep):
|
||||
"""Adds state features (and optionally an image spec to test precedence)."""
|
||||
|
||||
add_front_image: bool = False
|
||||
|
||||
@@ -30,13 +30,31 @@ from lerobot.policies.smolvla.processor_smolvla import (
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
ProcessorStep,
|
||||
RenameProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
"""Mock tokenizer processor step for testing."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||
pass
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Pass through transition unchanged
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
# Pass through features unchanged
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
@@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -170,7 +190,7 @@ def test_smolvla_processor_cuda():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -227,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -285,7 +305,7 @@ def test_smolvla_processor_multi_gpu():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -340,7 +360,9 @@ def test_smolvla_processor_without_stats():
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
|
||||
Reference in New Issue
Block a user