diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 2b949d5cb..b9b9c6c43 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -32,7 +32,6 @@ from __future__ import annotations import importlib import json -import os import re from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence @@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + _serialized_state_filenames: tuple[str | None, ...] | None = field( + default=None, + init=False, + repr=False, + ) def __call__(self, data: TInput) -> TOutput: """Processes input data through the full pipeline. @@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): transition = processor_step(transition) yield transition - def _save_pretrained(self, save_directory: Path, **kwargs): - """Internal method to comply with `HubMixin`'s saving mechanism. + def _get_sanitized_name(self) -> str: + """Return a filename-safe version of the pipeline name. - This method does the actual saving work and is called by HubMixin.save_pretrained. + Returns: + The lower-cased pipeline name with non-alphanumeric characters replaced by underscores. """ - config_filename = kwargs.pop("config_filename", None) + return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) - # Sanitize the pipeline name to create a valid filename prefix. - sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + @staticmethod + def _get_state_filename( + *, + step_index: int, + registry_name: str | None, + sanitized_name: str, + ) -> str: + """Return the safetensors filename for one stateful processor step. - if config_filename is None: - config_filename = f"{sanitized_name}.json" + Args: + step_index: The index of the processor step in this pipeline. + registry_name: The registered processor step name, if available. + sanitized_name: The filename-safe pipeline name. - config: dict[str, Any] = { + Returns: + The state filename used by the existing disk serialization format. + """ + if registry_name: + return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" + + return f"{sanitized_name}_step_{step_index}.safetensors" + + @staticmethod + def _get_state_key(state_filename: str) -> str: + """Return the in-memory state key for a serialized state filename. + + Args: + state_filename: The `.safetensors` filename from the serialized config. + + Returns: + The state key used by the in-memory pipeline state dictionary. + """ + return state_filename.removesuffix(".safetensors") + + @staticmethod + def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]: + """Return serialized state filenames in step order. + + Args: + loaded_config: A validated processor pipeline config. + + Returns: + A tuple containing each step's serialized state filename, or None for stateless steps. + """ + return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"]) + + def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]: + """Return expected state filenames in step order for `load_state_dict()`. + + Returns: + The preserved serialized state filenames when available, otherwise filenames derived from + current non-empty step state. + """ + if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len( + self.steps + ): + return self._serialized_state_filenames + + sanitized_name = self._get_sanitized_name() + state_filenames: list[str | None] = [] + + for step_index, processor_step in enumerate(self.steps): + step_state_dict = processor_step.state_dict() + if not step_state_dict: + state_filenames.append(None) + continue + + registry_name = getattr(processor_step.__class__, "_registry_name", None) + state_filenames.append( + self._get_state_filename( + step_index=step_index, + registry_name=registry_name, + sanitized_name=sanitized_name, + ) + ) + + return tuple(state_filenames) + + def get_config(self) -> dict[str, Any]: + """Return the JSON-serializable pipeline configuration. + + Returns: + A dictionary with the same content that `save_pretrained()` writes as JSON. + """ + sanitized_name = self._get_sanitized_name() + pipeline_config: dict[str, Any] = { "name": self.name, "steps": [], } - # Iterate through each step to build its configuration entry. for step_index, processor_step in enumerate(self.steps): registry_name = getattr(processor_step.__class__, "_registry_name", None) - step_entry: dict[str, Any] = {} - # Prefer registry name for portability, otherwise fall back to full class path. + if registry_name: step_entry["registry_name"] = registry_name else: @@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" ) - # Save step configuration if `get_config` is implemented. - if hasattr(processor_step, "get_config"): - step_entry["config"] = processor_step.get_config() + step_entry["config"] = processor_step.get_config() - # Save step state if `state_dict` is implemented and returns a non-empty dict. - if hasattr(processor_step, "state_dict"): - state = processor_step.state_dict() - if state: - # Clone tensors to avoid modifying the original state. - cloned_state = {key: tensor.clone() for key, tensor in state.items()} + step_state_dict = processor_step.state_dict() + if step_state_dict: + step_entry["state_file"] = self._get_state_filename( + step_index=step_index, + registry_name=registry_name, + sanitized_name=sanitized_name, + ) - # Create a unique filename for the state file. - if registry_name: - state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" - else: - state_filename = f"{sanitized_name}_step_{step_index}.safetensors" + pipeline_config["steps"].append(step_entry) - save_file(cloned_state, os.path.join(str(save_directory), state_filename)) - step_entry["state_file"] = state_filename + return pipeline_config - config["steps"].append(step_entry) + def state_dict(self) -> dict[str, dict[str, torch.Tensor]]: + """Return pipeline state tensors grouped by state key. - # Write the main configuration JSON file. - with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: - json.dump(config, file_pointer, indent=2) + Returns: + A dictionary mapping suffixless state keys to cloned step state dictionaries. + """ + sanitized_name = self._get_sanitized_name() + pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {} + + for step_index, processor_step in enumerate(self.steps): + step_state_dict = processor_step.state_dict() + if not step_state_dict: + continue + + registry_name = getattr(processor_step.__class__, "_registry_name", None) + state_filename = self._get_state_filename( + step_index=step_index, + registry_name=registry_name, + sanitized_name=sanitized_name, + ) + state_key = self._get_state_key(state_filename) + pipeline_state_dict[state_key] = { + tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items() + } + + return pipeline_state_dict + + def load_state_dict( + self, + state_dict: dict[str, dict[str, torch.Tensor]], + ) -> None: + """Load pipeline state tensors into the existing steps. + + Args: + state_dict: A dictionary mapping suffixless state keys to step state dictionaries. + + Raises: + KeyError: If loading finds missing expected state or unexpected extra state. + """ + expected_state_filenames = self._get_state_filenames_for_loading() + used_state_keys: set[str] = set() + + for step_index, (processor_step, state_filename) in enumerate( + zip(self.steps, expected_state_filenames, strict=True) + ): + if state_filename is None: + continue + + state_key = self._get_state_key(state_filename) + if state_key not in state_dict: + raise KeyError( + f"Missing state key '{state_key}' for processor step {step_index}. " + f"Available state keys: {sorted(state_dict.keys())}" + ) + + processor_step.load_state_dict(state_dict[state_key]) + used_state_keys.add(state_key) + + unexpected_state_keys = set(state_dict) - used_state_keys + if unexpected_state_keys: + expected_state_key_set = { + self._get_state_key(state_filename) + for state_filename in expected_state_filenames + if state_filename is not None + } + raise KeyError( + f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. " + f"Expected state keys: {sorted(expected_state_key_set)}" + ) + + def _save_pretrained(self, save_directory: Path, **kwargs) -> None: + """Internal method to comply with `HubMixin`'s saving mechanism. + + This method does the actual saving work and is called by HubMixin.save_pretrained. + """ + config_filename = kwargs.pop("config_filename", None) + sanitized_name = self._get_sanitized_name() + + if config_filename is None: + config_filename = f"{sanitized_name}.json" + + pipeline_config = self.get_config() + pipeline_state_dict = self.state_dict() + + for state_key, step_state_dict in pipeline_state_dict.items(): + state_filename = f"{state_key}.safetensors" + save_file(step_state_dict, save_directory / state_filename) + + with open(save_directory / config_filename, "w") as file_pointer: + json.dump(pipeline_config, file_pointer, indent=2) def save_pretrained( self, @@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): cls._validate_overrides_used(validated_overrides, loaded_config) # 5. Construct and return the final pipeline instance - return cls( + pipeline = cls( steps=steps, name=loaded_config.get("name", "DataProcessorPipeline"), to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition), to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), ) + pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config) + return pipeline + + @classmethod + def from_config( + cls, + config: dict[str, Any], + *, + state_dict: dict[str, dict[str, torch.Tensor]] | None = None, + overrides: dict[str, Any] | None = None, + to_transition: Callable[[TInput], EnvTransition] | None = None, + to_output: Callable[[EnvTransition], TOutput] | None = None, + ) -> DataProcessorPipeline[TInput, TOutput]: + """Build a pipeline from an in-memory config and optional state tensors. + + Args: + config: A config dictionary with the same structure as the saved processor JSON. + state_dict: Optional in-memory pipeline state grouped by suffixless state key. + overrides: Optional constructor overrides keyed by registry name or class name. + to_transition: Optional converter from input data to `EnvTransition`. + to_output: Optional converter from `EnvTransition` to output data. + + Returns: + A processor pipeline built from the config and optional state. + """ + cls._validate_loaded_config("", config, "") + + steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {}) + cls._validate_overrides_used(remaining_override_keys, config) + + pipeline = cls( + steps=steps, + name=config.get("name", "DataProcessorPipeline"), + to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition), + to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), + ) + pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config) + + if state_dict is not None: + pipeline.load_state_dict(state_dict) + + return pipeline @classmethod def _load_config( @@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): ) from e @classmethod - def _validate_loaded_config( - cls, model_id: str, loaded_config: dict[str, Any], config_filename: str - ) -> None: + def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None: """Validate that a config was loaded and is a valid processor config. This method validates processor config format with intelligent migration detection: @@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): Args: model_id: The model identifier (used for migration detection) - loaded_config: The loaded config dictionary (guaranteed non-None) + loaded_config: The loaded config value to validate (may be non-dict) config_filename: The config filename that was loaded (for error messages) Raises: @@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): model_id, f"Config file '{config_filename}' is not a valid processor configuration", ) + loaded_config_description = ( + list(loaded_config.keys()) + if isinstance(loaded_config, dict) + else type(loaded_config).__name__ + ) raise ValueError( f"Config file '{config_filename}' is not a valid processor configuration. " - f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}" + f"Expected a config with 'steps' field, but got: {loaded_config_description}" ) @classmethod @@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): ImportError: If a step class cannot be imported or found in registry ValueError: If a step cannot be instantiated with its configuration """ - steps: list[ProcessorStep] = [] - override_keys = set(overrides.keys()) + steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides) - for step_entry in loaded_config["steps"]: - # 1. Get step class and key - step_class, step_key = cls._resolve_step_class(step_entry) - - # 2. Instantiate step with overrides - step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides) - - # 3. Load step state if available + for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True): cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs) - # 4. Track used overrides - if step_key in override_keys: - override_keys.discard(step_key) + return steps, remaining_override_keys - steps.append(step_instance) + @classmethod + def _build_steps_from_config( + cls, + loaded_config: dict[str, Any], + overrides: dict[str, Any], + ) -> tuple[list[ProcessorStep], set[str]]: + """Build processor steps from config without loading tensor state. - return steps, override_keys + Args: + loaded_config: The loaded processor configuration. + overrides: User-provided constructor overrides keyed by step key. + + Returns: + A tuple containing instantiated steps and override keys that did not match a step. + """ + processor_steps: list[ProcessorStep] = [] + remaining_override_keys = set(overrides.keys()) + + for step_entry in loaded_config["steps"]: + step_class, step_key = cls._resolve_step_class(step_entry) + processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides) + + if step_key in remaining_override_keys: + remaining_override_keys.discard(step_key) + + processor_steps.append(processor_step) + + return processor_steps, remaining_override_keys @classmethod def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]: @@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): return True @classmethod - def _is_processor_config(cls, config: dict) -> bool: + def _is_processor_config(cls, config: Any) -> bool: """Check if config follows DataProcessorPipeline format. This method validates the processor configuration structure: @@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): Returns: True if config follows valid DataProcessorPipeline format, False otherwise """ + if not isinstance(config, dict): + return False + # Must have a "steps" field with a list of step configurations if not isinstance(config.get("steps"), list): return False diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 2c41de22c..57e948279 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -24,6 +24,7 @@ from typing import Any import pytest import torch import torch.nn as nn +from safetensors.torch import load_file pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") @@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep): return features +class MockLazyTensorStateStep(ProcessorStep): + """Mock step whose tensor state is not present in constructor config.""" + + def __init__( + self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None + ): + self.name = name + self.scale = scale + self.tensor_state: torch.Tensor | None = None + + if initial_value is not None: + self.tensor_state = torch.tensor([initial_value], dtype=torch.float32) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Return the transition unchanged.""" + return transition + + def get_config(self) -> dict[str, Any]: + """Return constructor config while intentionally omitting tensor state.""" + return { + "name": self.name, + "scale": self.scale, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return tensor state only after it has been initialized or loaded.""" + if self.tensor_state is None: + return {} + + return {"tensor_state": self.tensor_state} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load tensor state.""" + self.tensor_state = state["tensor_state"].clone() + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Return features unchanged.""" + return features + + +@ProcessorStepRegistry.register("registered_lazy_tensor_state_step") +class RegisteredLazyTensorStateStep(MockLazyTensorStateStep): + """Registered lazy tensor state step for registry-based serialization tests.""" + + def test_empty_pipeline(): """Test pipeline with no steps.""" pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition) @@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state(): assert torch.allclose(loaded_step.running_mean, step.running_mean) +def test_get_config_matches_saved_json(): + """Test that in-memory config matches the config written by save_pretrained.""" + stateless_step = MockStep(name="stateless") + stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0) + pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline") + + in_memory_config = pipeline.get_config() + + assert pipeline.get_config() == in_memory_config + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + config_path = Path(tmp_dir) / "memory_pipeline.json" + with open(config_path) as file_pointer: + saved_config = json.load(file_pointer) + + assert in_memory_config == saved_config + assert "state_file" not in in_memory_config["steps"][0] + assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors" + + +def test_state_dict_matches_saved_safetensors(): + """Test that in-memory state matches the safetensors written by save_pretrained.""" + stateful_step = MockLazyTensorStateStep(initial_value=7.0) + pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline") + + in_memory_state_dict = pipeline.state_dict() + state_filename = "stateful_pipeline_step_0.safetensors" + state_key = "stateful_pipeline_step_0" + + assert set(in_memory_state_dict) == {state_key} + assert set(in_memory_state_dict[state_key]) == {"tensor_state"} + + in_memory_state_dict[state_key]["tensor_state"].add_(1) + assert stateful_step.tensor_state is not None + assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0])) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + saved_state_dict = load_file(Path(tmp_dir) / state_filename) + + torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0])) + + +def test_save_pretrained_still_writes_expected_serialization_files(): + """Test that save_pretrained keeps the existing config and state filenames.""" + stateful_step = MockLazyTensorStateStep(initial_value=3.0) + pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + save_path = Path(tmp_dir) + assert (save_path / "policy_preprocessor.json").exists() + assert (save_path / "policy_preprocessor_step_0.safetensors").exists() + + +def test_from_config_round_trips_stateful_pipeline(): + """Test that from_config rebuilds a stateful pipeline from in-memory artifacts.""" + stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0) + pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + + loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict) + loaded_step = loaded_pipeline.steps[0] + + assert len(loaded_pipeline) == 1 + assert isinstance(loaded_step, MockLazyTensorStateStep) + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0])) + + +def test_from_config_round_trips_registered_stateful_pipeline(): + """Test that from_config resolves registry steps and loads their named tensor state.""" + stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0) + pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors" + state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step" + + assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step" + assert config["steps"][0]["state_file"] == state_filename + assert set(pipeline_state_dict) == {state_key} + + loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict) + loaded_step = loaded_pipeline.steps[0] + + assert isinstance(loaded_step, RegisteredLazyTensorStateStep) + assert loaded_step.tensor_state is not None + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0])) + + +def test_from_config_preserves_state_metadata_for_empty_initial_state(): + """Test in-memory loading when rebuilt steps start without tensor state.""" + stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0) + pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + + loaded_pipeline = DataProcessorPipeline.from_config(config) + loaded_step = loaded_pipeline.steps[0] + + assert isinstance(loaded_step, MockLazyTensorStateStep) + assert loaded_step.state_dict() == {} + assert "state_file" not in loaded_pipeline.get_config()["steps"][0] + + loaded_pipeline.load_state_dict(pipeline_state_dict) + + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0])) + + +def test_from_config_applies_overrides_before_state_loading(): + """Test that constructor overrides and tensor state loading are separate operations.""" + stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0) + pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + + loaded_pipeline = DataProcessorPipeline.from_config( + config, + state_dict=pipeline_state_dict, + overrides={"MockLazyTensorStateStep": {"scale": 5.0}}, + ) + loaded_step = loaded_pipeline.steps[0] + + assert isinstance(loaded_step, MockLazyTensorStateStep) + assert loaded_step.scale == 5.0 + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0])) + + +def test_load_state_dict_raises_on_missing_expected_state(): + """Test loading raises when serialized config expects missing state.""" + stateful_step = MockLazyTensorStateStep(initial_value=19.0) + pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline") + loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config()) + + with pytest.raises(KeyError, match="missing_pipeline_step_0"): + loaded_pipeline.load_state_dict({}) + + +def test_load_state_dict_raises_on_unexpected_extra_state(): + """Test loading raises on unexpected top-level state keys.""" + pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline") + + with pytest.raises(KeyError, match="extra"): + pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}}) + + +def test_stateless_pipeline_in_memory_serialization_returns_empty_state(): + """Test stateless in-memory serialization and loading.""" + pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline") + config = pipeline.get_config() + config_without_name = {"steps": config["steps"]} + + assert pipeline.state_dict() == {} + assert all("state_file" not in step_entry for step_entry in config["steps"]) + + loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={}) + + assert loaded_pipeline.name == "DataProcessorPipeline" + assert loaded_pipeline.state_dict() == {} + + +@pytest.mark.parametrize("invalid_config", [None, [], "not config"]) +def test_from_config_rejects_non_dict_config(invalid_config): + """Test from_config reports invalid top-level config values cleanly.""" + with pytest.raises(ValueError, match="not a valid processor configuration"): + DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type] + + class MockModuleStep(ProcessorStep, nn.Module): """Mock step that inherits from nn.Module to test state_dict handling of module parameters."""