From 2c4e888c7f2c12d02f762df5b97ee51a253eb45f Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 31 Jul 2025 16:29:48 +0200 Subject: [PATCH] Feat/pipeline add feature contract (#1637) * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops --- src/lerobot/processor/device_processor.py | 4 + src/lerobot/processor/normalize_processor.py | 6 + .../processor/observation_processor.py | 46 ++++ src/lerobot/processor/pipeline.py | 39 ++++ src/lerobot/processor/rename_processor.py | 8 + tests/conftest.py | 17 ++ tests/processor/test_observation_processor.py | 79 +++++++ tests/processor/test_pipeline.py | 214 ++++++++++++++++++ tests/processor/test_rename_processor.py | 59 +++++ 9 files changed, 472 insertions(+) diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 232454850..8d7d04878 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -18,6 +18,7 @@ from typing import Any import torch +from lerobot.configs.types import PolicyFeature from lerobot.processor.pipeline import EnvTransition, TransitionKey @@ -74,3 +75,6 @@ class DeviceProcessor: def get_config(self) -> dict[str, Any]: """Return configuration for serialization.""" return {"device": self.device} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 70c4f764f..a8424013c 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -204,6 +204,9 @@ class NormalizerProcessor: def reset(self): pass + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + @dataclass @ProcessorStepRegistry.register(name="unnormalizer_processor") @@ -327,3 +330,6 @@ class UnnormalizerProcessor: def reset(self): pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index bee33f434..091b1286d 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,6 +21,7 @@ import numpy as np import torch from torch import Tensor +from lerobot.configs.types import PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey @@ -110,6 +111,27 @@ class ImageProcessor: """Reset processor state (no-op for this processor).""" pass + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms: + pixels -> OBS_IMAGE, + observation.pixels -> OBS_IMAGE, + pixels. -> OBS_IMAGES., + observation.pixels. -> OBS_IMAGES. + """ + if "pixels" in features: + features[OBS_IMAGE] = features.pop("pixels") + if "observation.pixels" in features: + features[OBS_IMAGE] = features.pop("observation.pixels") + + prefixes = ("pixels.", "observation.pixels.") + for key in list(features.keys()): + for p in prefixes: + if key.startswith(p): + suffix = key[len(p) :] + features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key) + break + return features + @dataclass class StateProcessor: @@ -169,6 +191,25 @@ class StateProcessor: """Reset processor state (no-op for this processor).""" pass + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms: + environment_state -> OBS_ENV_STATE, + agent_pos -> OBS_STATE, + observation.environment_state -> OBS_ENV_STATE, + observation.agent_pos -> OBS_STATE + """ + pairs = ( + ("environment_state", OBS_ENV_STATE), + ("agent_pos", OBS_STATE), + ) + for old, new in pairs: + if old in features: + features[new] = features.pop(old) + prefixed = f"observation.{old}" + if prefixed in features: + features[new] = features.pop(prefixed) + return features + @dataclass @ProcessorStepRegistry.register(name="observation_processor") @@ -219,3 +260,8 @@ class VanillaObservationProcessor: """Reset processor state.""" self.image_processor.reset() self.state_processor.reset() + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features = self.image_processor.feature_contract(features) + features = self.state_processor.feature_contract(features) + return features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 1ecae4892..f945f367b 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -19,6 +19,7 @@ import importlib import json import os from collections.abc import Callable, Iterable, Sequence +from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pathlib import Path @@ -29,6 +30,7 @@ from huggingface_hub import ModelHubMixin, hf_hub_download from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_file, save_file +from lerobot.configs.types import PolicyFeature from lerobot.utils.utils import get_safe_torch_device @@ -141,6 +143,11 @@ class ProcessorStep(Protocol): automatically serialise the step's configuration and learnable state using a safe-to-share JSON + SafeTensors format. + + **Required**: + - ``__call__(transition: EnvTransition) -> EnvTransition`` + - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` + Optional helper protocol: * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable configuration and state. YOU decide what to save here. This is where all @@ -168,6 +175,8 @@ class ProcessorStep(Protocol): def reset(self) -> None: ... + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 """Convert a *batch* dict coming from Learobot replay/dataset code into an @@ -840,6 +849,33 @@ class RobotProcessor(ModelHubMixin): return f"RobotProcessor({', '.join(parts)})" + def __post_init__(self): + for i, step in enumerate(self.steps): + if not callable(step): + raise TypeError( + f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" + ) + + fc = getattr(step, "feature_contract", None) + if not callable(fc): + raise TypeError( + f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" + ) + + def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Apply ALL steps in order. Each step must implement + feature_contract(features) and return a dict (full or incremental schema). + """ + features: dict[str, PolicyFeature] = deepcopy(initial_features) + + for _, step in enumerate(self.steps): + out = step.feature_contract(features) + if not isinstance(out, dict): + raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + features = out + return features + class ObservationProcessor: """Base class for processors that modify only the observation component of a transition. @@ -1145,3 +1181,6 @@ class IdentityProcessor: def reset(self) -> None: pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 08855e237..7e1897541 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -18,6 +18,7 @@ from typing import Any import torch +from lerobot.configs.types import PolicyFeature from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey @@ -53,3 +54,10 @@ class RenameProcessor: def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms: + - Each key in the observation that appears in `rename_map` is renamed to its value. + - Keys not in `rename_map` remain unchanged. + """ + return {self.rename_map.get(k, k): v for k, v in features.items()} diff --git a/tests/conftest.py b/tests/conftest.py index 69dd3049b..7940cc5ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ import traceback import pytest from serial import SerialException +from lerobot.configs.types import FeatureType, PolicyFeature from tests.utils import DEVICE # Import fixture modules as plugins @@ -69,3 +70,19 @@ def patch_builtins_input(monkeypatch): print(text) monkeypatch.setattr("builtins.input", print_text) + + +@pytest.fixture +def policy_feature_factory(): + """PolicyFeature factory""" + + def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature: + return PolicyFeature(type=ft, shape=shape) + + return _pf + + +def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None: + assert isinstance(features, dict) + assert all(isinstance(k, str) for k in features.keys()) + assert all(isinstance(v, PolicyFeature) for v in features.values()) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 5026a9177..fb6a78155 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -18,12 +18,15 @@ import numpy as np import pytest import torch +from lerobot.configs.types import FeatureType +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor import ( ImageProcessor, StateProcessor, VanillaObservationProcessor, ) from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed def create_transition( @@ -420,3 +423,79 @@ def test_equivalent_with_image_dict(): for key in original_result: torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): + processor = ImageProcessor() + features = { + "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] + assert "pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): + processor = ImageProcessor() + features = { + "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] + assert "observation.pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): + processor = ImageProcessor() + features = { + "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (7,)), + } + out = processor.feature_contract(features.copy()) + + assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] + assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] + assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] + assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): + processor = StateProcessor() + features = { + "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), + "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): + proc = StateProcessor() + features = { + "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + } + out = proc.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert_contract_is_typed(out) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 68405648c..8c12e9167 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -16,6 +16,7 @@ import json import tempfile +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path from typing import Any @@ -25,8 +26,10 @@ import pytest import torch import torch.nn as nn +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed def create_transition( @@ -88,6 +91,10 @@ class MockStep: def reset(self) -> None: self.counter = 0 + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + @dataclass class MockStepWithoutOptionalMethods: @@ -106,6 +113,10 @@ class MockStepWithoutOptionalMethods: return transition + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + @dataclass class MockStepWithTensorState: @@ -158,6 +169,10 @@ class MockStepWithTensorState: self.running_mean.zero_() self.running_count.zero_() + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + def test_empty_pipeline(): """Test pipeline with no steps.""" @@ -699,6 +714,10 @@ class MockModuleStep(nn.Module): self.running_mean.zero_() self.counter = 0 + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + def test_to_device_with_state_dict(): """Test moving pipeline to device for steps with state_dict.""" @@ -953,6 +972,10 @@ class MockNonModuleStepWithState: self.step_count.zero_() self.history.clear() + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + def test_to_device_non_module_class(): """Test moving pipeline to device for regular classes (non nn.Module) with tensor state. @@ -1127,6 +1150,10 @@ class MockStepWithNonSerializableParam: def reset(self) -> None: pass + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + @ProcessorStepRegistry.register("registered_mock_step") @dataclass @@ -1162,6 +1189,10 @@ class RegisteredMockStep: def reset(self) -> None: pass + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + class MockEnvironment: """Mock environment for testing non-serializable parameters.""" @@ -1483,6 +1514,10 @@ class MockStepWithMixedState: "list_value": self.list_value, } + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + def test_to_device_with_mixed_state_types(): """Test that to() only moves tensor state, while non-tensor state remains in config.""" @@ -1790,6 +1825,10 @@ def test_state_file_naming_with_registry(): def load_state_dict(self, state): self.state_tensor = state["state_tensor"] + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + try: # Create pipeline with registered steps step1 = TestStatefulStep(1) @@ -1843,6 +1882,10 @@ def test_override_with_nested_config(): def get_config(self): return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + try: step = ComplexConfigStep() pipeline = RobotProcessor([step]) @@ -1931,6 +1974,10 @@ def test_override_with_callables(): def get_config(self): return {"name": self.name} + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + try: step = CallableStep() pipeline = RobotProcessor([step]) @@ -2059,6 +2106,10 @@ def test_override_with_device_strings(): def load_state_dict(self, state): self.buffer = state["buffer"] + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + try: step = DeviceAwareStep(device="cpu") pipeline = RobotProcessor([step]) @@ -2146,3 +2197,166 @@ def test_save_load_with_custom_converter_functions(): # Should work with standard format (wouldn't work with custom converter) result = loaded(batch) assert "observation.image" in result # Standard format preserved + + +class NonCompliantStep: + """Intentionally non-compliant: missing feature_contract.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + +def test_construction_rejects_step_without_feature_contract(): + with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): + RobotProcessor([NonCompliantStep()]) + + +class NonCallableStep: + """Intentionally non-compliant: missing __call__.""" + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +def test_construction_rejects_step_without_call(): + with pytest.raises(TypeError, match=r"must define __call__"): + RobotProcessor([NonCallableStep()]) + + +@dataclass +class FeatureContractAddStep: + """Adds a PolicyFeature""" + + key: str = "a" + value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,)) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features[self.key] = self.value + return features + + +@dataclass +class FeatureContractMutateStep: + """Mutates a PolicyFeature""" + + key: str = "a" + fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features[self.key] = self.fn(features.get(self.key)) + return features + + +@dataclass +class FeatureContractBadReturnStep: + """Returns a non-dict""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return ["not-a-dict"] + + +@dataclass +class FeatureContractRemoveStep: + """Removes a PolicyFeature""" + + key: str + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features.pop(self.key, None) + return features + + +def test_feature_contract_orders_and_merges(policy_feature_factory): + p = RobotProcessor( + [ + FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), + FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))), + FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), + ] + ) + out = p.feature_contract({}) + + assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) + assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) + assert_contract_is_typed(out) + + +def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): + initial = { + "seed": policy_feature_factory(FeatureType.STATE, (7,)), + "nested": policy_feature_factory(FeatureType.ENV, (0,)), + } + p = RobotProcessor( + [ + FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))), + FeatureContractMutateStep( + "nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,)) + ), + ] + ) + out = p.feature_contract(initial_features=initial) + + assert out["seed"].shape == (8,) + assert out["nested"].shape == (5,) + # Initial dict must be preserved + assert initial["seed"].shape == (7,) + assert initial["nested"].shape == (0,) + + assert_contract_is_typed(out) + + +def test_feature_contract_type_error_on_bad_step(): + p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) + with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): + _ = p.feature_contract({}) + + +def test_feature_contract_execution_order_tracking(): + class Track: + def __init__(self, label): + self.label = label + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + code = {"A": 1, "B": 2, "C": 3}[self.label] + pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) + features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) + return features + + out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) + assert out["order"].shape == (1, 2, 3) + + +def test_feature_contract_remove_key(policy_feature_factory): + p = RobotProcessor( + [ + FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), + FeatureContractRemoveStep("a"), + ] + ) + out = p.feature_contract({}) + assert "a" not in out + + +def test_feature_contract_remove_from_initial(policy_feature_factory): + initial = { + "keep": policy_feature_factory(FeatureType.STATE, (1,)), + "drop": policy_feature_factory(FeatureType.STATE, (1,)), + } + p = RobotProcessor([FeatureContractRemoveStep("drop")]) + out = p.feature_contract(initial_features=initial) + assert "drop" not in out and out["keep"] == initial["keep"] diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index b310a7a8f..229d57f9f 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -19,7 +19,9 @@ from pathlib import Path import numpy as np import torch +from lerobot.configs.types import FeatureType from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from tests.conftest import assert_contract_is_typed def create_transition( @@ -406,3 +408,60 @@ def test_value_types_preserved(): assert processed_obs["old_string"] == "hello" assert processed_obs["old_dict"] == {"nested": "value"} assert processed_obs["old_list"] == [1, 2, 3] + + +def test_feature_contract_basic_renaming(policy_feature_factory): + processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) + features = { + "a": policy_feature_factory(FeatureType.STATE, (2,)), + "b": policy_feature_factory(FeatureType.ACTION, (3,)), + "c": policy_feature_factory(FeatureType.ENV, (1,)), + } + + out = processor.feature_contract(features.copy()) + + # Values preserved and typed + assert out["x"] == features["a"] + assert out["y"] == features["b"] + assert out["c"] == features["c"] + + assert_contract_is_typed(out) + # Input not mutated + assert set(features) == {"a", "b", "c"} + + +def test_feature_contract_overlapping_keys(policy_feature_factory): + # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c' + processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) + features = { + "a": policy_feature_factory(FeatureType.STATE, (1,)), + "b": policy_feature_factory(FeatureType.STATE, (2,)), + } + out = processor.feature_contract(features) + + assert set(out) == {"b", "c"} + assert out["b"] == features["a"] # 'a' renamed to'b' + assert out["c"] == features["b"] # 'b' renamed to 'c' + assert_contract_is_typed(out) + + +def test_feature_contract_chained_processors(policy_feature_factory): + # Chain two rename processors at the contract level + processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) + processor2 = RenameProcessor( + rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} + ) + pipeline = RobotProcessor([processor1, processor2]) + + spec = { + "pos": policy_feature_factory(FeatureType.STATE, (7,)), + "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "extra": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = pipeline.feature_contract(initial_features=spec) + + assert set(out) == {"observation.state", "observation.image", "extra"} + assert out["observation.state"] == spec["pos"] + assert out["observation.image"] == spec["img"] + assert out["extra"] == spec["extra"] + assert_contract_is_typed(out)