feat(processor): Add in-memory processor pipeline serialization (#3732)

* feat(processor): add in-memory pipeline serialization

Expose processor pipeline config and tensor state without requiring temporary files, so processors can be transported, compared, or hashed directly in memory.

* feat(processor): enhance DataProcessorPipeline with registry support

- Added a new RegisteredLazyTensorStateStep for registry-based serialization tests.
- Improved state filename handling in _get_state_filename method.
- Refactored validation logic in _validate_loaded_config to simplify parameter types.
- Updated tests to verify registry step functionality and ensure correct state loading.

* refactor(processor): update state handling in DataProcessorPipeline

- Introduced a new static method _get_state_key to derive in-memory state keys from serialized filenames.
- Updated state_dict and load_state_dict methods to use suffixless state keys instead of filenames.
- Adjusted related tests to reflect changes in state key handling, ensuring consistency in state management

* fix(processor): update loaded_config argument description in DataProcessorPipeline

- Clarified the documentation for the loaded_config parameter to indicate that it may be a non-dictionary value, enhancing understanding for future developers.
This commit is contained in:
Adil Zouitine
2026-06-08 11:27:24 +02:00
committed by GitHub
parent 09808183ca
commit 49755a3d9e
2 changed files with 499 additions and 55 deletions
+279 -55
View File
@@ -32,7 +32,6 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
This method does the actual saving work and is called by HubMixin.save_pretrained.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
"""
config_filename = kwargs.pop("config_filename", None)
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
config: dict[str, Any] = {
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_entry["config"] = processor_step.get_config()
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
pipeline_config["steps"].append(step_entry)
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
return pipeline_config
config["steps"].append(step_entry)
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
return cls(
pipeline = cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config dictionary (guaranteed non-None)
loaded_config: The loaded config value to validate (may be non-dict)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
)
@classmethod
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
# 3. Load step state if available
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
return steps, remaining_override_keys
steps.append(step_instance)
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
return steps, override_keys
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
processor_steps.append(processor_step)
return processor_steps, remaining_override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: dict) -> bool:
def _is_processor_config(cls, config: Any) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
+220
View File
@@ -24,6 +24,7 @@ from typing import Any
import pytest
import torch
import torch.nn as nn
from safetensors.torch import load_file
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
return features
class MockLazyTensorStateStep(ProcessorStep):
"""Mock step whose tensor state is not present in constructor config."""
def __init__(
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
):
self.name = name
self.scale = scale
self.tensor_state: torch.Tensor | None = None
if initial_value is not None:
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Return the transition unchanged."""
return transition
def get_config(self) -> dict[str, Any]:
"""Return constructor config while intentionally omitting tensor state."""
return {
"name": self.name,
"scale": self.scale,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return tensor state only after it has been initialized or loaded."""
if self.tensor_state is None:
return {}
return {"tensor_state": self.tensor_state}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load tensor state."""
self.tensor_state = state["tensor_state"].clone()
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Return features unchanged."""
return features
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
"""Registered lazy tensor state step for registry-based serialization tests."""
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean)
def test_get_config_matches_saved_json():
"""Test that in-memory config matches the config written by save_pretrained."""
stateless_step = MockStep(name="stateless")
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
in_memory_config = pipeline.get_config()
assert pipeline.get_config() == in_memory_config
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
config_path = Path(tmp_dir) / "memory_pipeline.json"
with open(config_path) as file_pointer:
saved_config = json.load(file_pointer)
assert in_memory_config == saved_config
assert "state_file" not in in_memory_config["steps"][0]
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
def test_state_dict_matches_saved_safetensors():
"""Test that in-memory state matches the safetensors written by save_pretrained."""
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
in_memory_state_dict = pipeline.state_dict()
state_filename = "stateful_pipeline_step_0.safetensors"
state_key = "stateful_pipeline_step_0"
assert set(in_memory_state_dict) == {state_key}
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
in_memory_state_dict[state_key]["tensor_state"].add_(1)
assert stateful_step.tensor_state is not None
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
def test_save_pretrained_still_writes_expected_serialization_files():
"""Test that save_pretrained keeps the existing config and state filenames."""
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
save_path = Path(tmp_dir)
assert (save_path / "policy_preprocessor.json").exists()
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
def test_from_config_round_trips_stateful_pipeline():
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert len(loaded_pipeline) == 1
assert isinstance(loaded_step, MockLazyTensorStateStep)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
def test_from_config_round_trips_registered_stateful_pipeline():
"""Test that from_config resolves registry steps and loads their named tensor state."""
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
assert config["steps"][0]["state_file"] == state_filename
assert set(pipeline_state_dict) == {state_key}
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
assert loaded_step.tensor_state is not None
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
def test_from_config_preserves_state_metadata_for_empty_initial_state():
"""Test in-memory loading when rebuilt steps start without tensor state."""
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.state_dict() == {}
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
loaded_pipeline.load_state_dict(pipeline_state_dict)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
def test_from_config_applies_overrides_before_state_loading():
"""Test that constructor overrides and tensor state loading are separate operations."""
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(
config,
state_dict=pipeline_state_dict,
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.scale == 5.0
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
def test_load_state_dict_raises_on_missing_expected_state():
"""Test loading raises when serialized config expects missing state."""
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
loaded_pipeline.load_state_dict({})
def test_load_state_dict_raises_on_unexpected_extra_state():
"""Test loading raises on unexpected top-level state keys."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
with pytest.raises(KeyError, match="extra"):
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
"""Test stateless in-memory serialization and loading."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
config = pipeline.get_config()
config_without_name = {"steps": config["steps"]}
assert pipeline.state_dict() == {}
assert all("state_file" not in step_entry for step_entry in config["steps"])
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
assert loaded_pipeline.name == "DataProcessorPipeline"
assert loaded_pipeline.state_dict() == {}
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
"""Test from_config reports invalid top-level config values cleanly."""
with pytest.raises(ValueError, match="not a valid processor configuration"):
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
class MockModuleStep(ProcessorStep, nn.Module):
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""