mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862)
- Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity.
- Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests.
- Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability.
This commit is contained in:
@@ -731,11 +731,8 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]):
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
for i, step in enumerate(self.steps):
|
for i, step in enumerate(self.steps):
|
||||||
if not callable(step):
|
if not isinstance(step, ProcessorStep):
|
||||||
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
|
raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
|
||||||
raise TypeError(
|
|
||||||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
|
||||||
)
|
|
||||||
|
|
||||||
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -27,13 +27,31 @@ from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre
|
|||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
|
EnvTransition,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
|
ProcessorStep,
|
||||||
RenameProcessorStep,
|
RenameProcessorStep,
|
||||||
TransitionKey,
|
TransitionKey,
|
||||||
UnnormalizerProcessorStep,
|
UnnormalizerProcessorStep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
|
"""Mock tokenizer processor step for testing."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
# Pass through transition unchanged
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
# Pass through features unchanged
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def create_transition(observation=None, action=None, **kwargs):
|
def create_transition(observation=None, action=None, **kwargs):
|
||||||
"""Helper function to create a transition dictionary."""
|
"""Helper function to create a transition dictionary."""
|
||||||
transition = {}
|
transition = {}
|
||||||
@@ -83,7 +101,7 @@ def test_make_pi0_processor_basic():
|
|||||||
config = create_default_config()
|
config = create_default_config()
|
||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"):
|
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
stats,
|
stats,
|
||||||
@@ -165,7 +183,7 @@ def test_pi0_processor_cuda():
|
|||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
# Mock the tokenizer processor to act as pass-through
|
# Mock the tokenizer processor to act as pass-through
|
||||||
class MockTokenizerProcessorStep:
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario():
|
|||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
# Mock the tokenizer processor to act as pass-through
|
# Mock the tokenizer processor to act as pass-through
|
||||||
class MockTokenizerProcessorStep:
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu():
|
|||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
# Mock the tokenizer processor to act as pass-through
|
# Mock the tokenizer processor to act as pass-through
|
||||||
class MockTokenizerProcessorStep:
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -329,7 +347,7 @@ def test_pi0_processor_without_stats():
|
|||||||
config = create_default_config()
|
config = create_default_config()
|
||||||
|
|
||||||
# Mock the tokenizer processor
|
# Mock the tokenizer processor
|
||||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep"):
|
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
dataset_stats=None,
|
dataset_stats=None,
|
||||||
|
|||||||
@@ -27,7 +27,13 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||||
from lerobot.processor import DataProcessorPipeline, EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor import (
|
||||||
|
DataProcessorPipeline,
|
||||||
|
EnvTransition,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
TransitionKey,
|
||||||
|
)
|
||||||
from tests.conftest import assert_contract_is_typed
|
from tests.conftest import assert_contract_is_typed
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +53,7 @@ def create_transition(
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStep:
|
class MockStep(ProcessorStep):
|
||||||
"""Mock pipeline step for testing - demonstrates best practices.
|
"""Mock pipeline step for testing - demonstrates best practices.
|
||||||
|
|
||||||
This example shows the proper separation:
|
This example shows the proper separation:
|
||||||
@@ -96,7 +102,7 @@ class MockStep:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStepWithoutOptionalMethods:
|
class MockStepWithoutOptionalMethods(ProcessorStep):
|
||||||
"""Mock step that only implements the required __call__ method."""
|
"""Mock step that only implements the required __call__ method."""
|
||||||
|
|
||||||
multiplier: float = 2.0
|
multiplier: float = 2.0
|
||||||
@@ -118,7 +124,7 @@ class MockStepWithoutOptionalMethods:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStepWithTensorState:
|
class MockStepWithTensorState(ProcessorStep):
|
||||||
"""Mock step demonstrating mixed JSON attributes and tensor state."""
|
"""Mock step demonstrating mixed JSON attributes and tensor state."""
|
||||||
|
|
||||||
name: str = "tensor_step"
|
name: str = "tensor_step"
|
||||||
@@ -613,7 +619,7 @@ def test_mixed_json_and_tensor_state():
|
|||||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||||
|
|
||||||
|
|
||||||
class MockModuleStep(nn.Module):
|
class MockModuleStep(ProcessorStep, nn.Module):
|
||||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||||
|
|
||||||
def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
|
def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
|
||||||
@@ -653,12 +659,12 @@ class MockModuleStep(nn.Module):
|
|||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
"""Override to return all module parameters and buffers."""
|
"""Override to return all module parameters and buffers."""
|
||||||
# Get the module's state dict (includes all parameters and buffers)
|
# Get the module's state dict (includes all parameters and buffers)
|
||||||
return super().state_dict()
|
return nn.Module.state_dict(self)
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
"""Override to load all module parameters and buffers."""
|
"""Override to load all module parameters and buffers."""
|
||||||
# Use the module's load_state_dict
|
# Use the module's load_state_dict
|
||||||
super().load_state_dict(state)
|
nn.Module.load_state_dict(self, state)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.running_mean.zero_()
|
self.running_mean.zero_()
|
||||||
@@ -669,7 +675,7 @@ class MockModuleStep(nn.Module):
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
class MockNonModuleStepWithState:
|
class MockNonModuleStepWithState(ProcessorStep):
|
||||||
"""Mock step that explicitly does NOT inherit from nn.Module but has tensor state.
|
"""Mock step that explicitly does NOT inherit from nn.Module but has tensor state.
|
||||||
|
|
||||||
This tests the state_dict/load_state_dict path for regular classes.
|
This tests the state_dict/load_state_dict path for regular classes.
|
||||||
@@ -753,7 +759,7 @@ class MockNonModuleStepWithState:
|
|||||||
|
|
||||||
# Tests for overrides functionality
|
# Tests for overrides functionality
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockStepWithNonSerializableParam:
|
class MockStepWithNonSerializableParam(ProcessorStep):
|
||||||
"""Mock step that requires a non-serializable parameter."""
|
"""Mock step that requires a non-serializable parameter."""
|
||||||
|
|
||||||
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
|
def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None):
|
||||||
@@ -808,7 +814,7 @@ class MockStepWithNonSerializableParam:
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register("registered_mock_step")
|
@ProcessorStepRegistry.register("registered_mock_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class RegisteredMockStep:
|
class RegisteredMockStep(ProcessorStep):
|
||||||
"""Mock step registered in the registry."""
|
"""Mock step registered in the registry."""
|
||||||
|
|
||||||
value: int = 42
|
value: int = 42
|
||||||
@@ -1381,7 +1387,7 @@ def test_state_file_naming_with_registry():
|
|||||||
# Register a test step
|
# Register a test step
|
||||||
@ProcessorStepRegistry.register("test_stateful_step")
|
@ProcessorStepRegistry.register("test_stateful_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestStatefulStep:
|
class TestStatefulStep(ProcessorStep):
|
||||||
value: int = 0
|
value: int = 0
|
||||||
|
|
||||||
def __init__(self, value: int = 0):
|
def __init__(self, value: int = 0):
|
||||||
@@ -1436,7 +1442,7 @@ def test_override_with_nested_config():
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register("complex_config_step")
|
@ProcessorStepRegistry.register("complex_config_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class ComplexConfigStep:
|
class ComplexConfigStep(ProcessorStep):
|
||||||
name: str = "complex"
|
name: str = "complex"
|
||||||
simple_param: int = 42
|
simple_param: int = 42
|
||||||
nested_config: dict = None
|
nested_config: dict = None
|
||||||
@@ -1532,7 +1538,7 @@ def test_override_with_callables():
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register("callable_step")
|
@ProcessorStepRegistry.register("callable_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class CallableStep:
|
class CallableStep(ProcessorStep):
|
||||||
name: str = "callable_step"
|
name: str = "callable_step"
|
||||||
transform_fn: Any = None
|
transform_fn: Any = None
|
||||||
|
|
||||||
@@ -1667,7 +1673,7 @@ def test_override_with_device_strings():
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register("device_aware_step")
|
@ProcessorStepRegistry.register("device_aware_step")
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeviceAwareStep:
|
class DeviceAwareStep(ProcessorStep):
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
|
|
||||||
def __init__(self, device: str = "cpu"):
|
def __init__(self, device: str = "cpu"):
|
||||||
@@ -1806,13 +1812,17 @@ class NonCallableStep:
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
def test_construction_rejects_step_without_call():
|
def test_construction_rejects_step_without_processorstep():
|
||||||
with pytest.raises(TypeError, match=r"must define __call__"):
|
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
||||||
|
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||||
DataProcessorPipeline([NonCallableStep()])
|
DataProcessorPipeline([NonCallableStep()])
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||||
|
DataProcessorPipeline([NonCompliantStep()])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FeatureContractAddStep:
|
class FeatureContractAddStep(ProcessorStep):
|
||||||
"""Adds a PolicyFeature"""
|
"""Adds a PolicyFeature"""
|
||||||
|
|
||||||
key: str = "a"
|
key: str = "a"
|
||||||
@@ -1827,7 +1837,7 @@ class FeatureContractAddStep:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FeatureContractMutateStep:
|
class FeatureContractMutateStep(ProcessorStep):
|
||||||
"""Mutates a PolicyFeature"""
|
"""Mutates a PolicyFeature"""
|
||||||
|
|
||||||
key: str = "a"
|
key: str = "a"
|
||||||
@@ -1842,7 +1852,7 @@ class FeatureContractMutateStep:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FeatureContractBadReturnStep:
|
class FeatureContractBadReturnStep(ProcessorStep):
|
||||||
"""Returns a non-dict"""
|
"""Returns a non-dict"""
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -1853,7 +1863,7 @@ class FeatureContractBadReturnStep:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FeatureContractRemoveStep:
|
class FeatureContractRemoveStep(ProcessorStep):
|
||||||
"""Removes a PolicyFeature"""
|
"""Removes a PolicyFeature"""
|
||||||
|
|
||||||
key: str
|
key: str
|
||||||
@@ -1906,7 +1916,7 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
|
|||||||
|
|
||||||
|
|
||||||
def test_features_execution_order_tracking():
|
def test_features_execution_order_tracking():
|
||||||
class Track:
|
class Track(ProcessorStep):
|
||||||
def __init__(self, label):
|
def __init__(self, label):
|
||||||
self.label = label
|
self.label = label
|
||||||
|
|
||||||
@@ -1945,7 +1955,7 @@ def test_features_remove_from_initial(policy_feature_factory):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AddActionEEAndJointFeatures:
|
class AddActionEEAndJointFeatures(ProcessorStep):
|
||||||
"""Adds both EE and JOINT action features."""
|
"""Adds both EE and JOINT action features."""
|
||||||
|
|
||||||
def __call__(self, tr):
|
def __call__(self, tr):
|
||||||
@@ -1962,7 +1972,7 @@ class AddActionEEAndJointFeatures:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AddObservationStateFeatures:
|
class AddObservationStateFeatures(ProcessorStep):
|
||||||
"""Adds state features (and optionally an image spec to test precedence)."""
|
"""Adds state features (and optionally an image spec to test precedence)."""
|
||||||
|
|
||||||
add_front_image: bool = False
|
add_front_image: bool = False
|
||||||
|
|||||||
@@ -30,13 +30,31 @@ from lerobot.policies.smolvla.processor_smolvla import (
|
|||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
|
EnvTransition,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
|
ProcessorStep,
|
||||||
RenameProcessorStep,
|
RenameProcessorStep,
|
||||||
TransitionKey,
|
TransitionKey,
|
||||||
UnnormalizerProcessorStep,
|
UnnormalizerProcessorStep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
|
"""Mock tokenizer processor step for testing."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
# Pass through transition unchanged
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
# Pass through features unchanged
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
def create_transition(observation=None, action=None, **kwargs):
|
def create_transition(observation=None, action=None, **kwargs):
|
||||||
"""Helper function to create a transition dictionary."""
|
"""Helper function to create a transition dictionary."""
|
||||||
transition = {}
|
transition = {}
|
||||||
@@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic():
|
|||||||
config = create_default_config()
|
config = create_default_config()
|
||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"):
|
with patch(
|
||||||
|
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||||
|
):
|
||||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
stats,
|
stats,
|
||||||
@@ -170,7 +190,7 @@ def test_smolvla_processor_cuda():
|
|||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
# Mock the tokenizer processor to act as pass-through
|
# Mock the tokenizer processor to act as pass-through
|
||||||
class MockTokenizerProcessorStep:
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -227,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario():
|
|||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
# Mock the tokenizer processor to act as pass-through
|
# Mock the tokenizer processor to act as pass-through
|
||||||
class MockTokenizerProcessorStep:
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -285,7 +305,7 @@ def test_smolvla_processor_multi_gpu():
|
|||||||
stats = create_default_stats()
|
stats = create_default_stats()
|
||||||
|
|
||||||
# Mock the tokenizer processor to act as pass-through
|
# Mock the tokenizer processor to act as pass-through
|
||||||
class MockTokenizerProcessorStep:
|
class MockTokenizerProcessorStep(ProcessorStep):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -340,7 +360,9 @@ def test_smolvla_processor_without_stats():
|
|||||||
config = create_default_config()
|
config = create_default_config()
|
||||||
|
|
||||||
# Mock the tokenizer processor
|
# Mock the tokenizer processor
|
||||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep"):
|
with patch(
|
||||||
|
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||||
|
):
|
||||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
dataset_stats=None,
|
dataset_stats=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user