Feat/pipeline add feature contract (#1637)

* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* encorperate pr feedback

* type in doc

* oops
This commit is contained in:
Pepijn
2025-07-31 16:29:48 +02:00
committed by Adil Zouitine
parent 5ced72e6b8
commit 2c4e888c7f
9 changed files with 472 additions and 0 deletions
@@ -18,12 +18,15 @@ import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import (
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
@@ -420,3 +423,79 @@ def test_equivalent_with_image_dict():
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
processor = ImageProcessor()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
assert "pixels" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
processor = ImageProcessor()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
assert "observation.pixels" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
processor = ImageProcessor()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
}
out = processor.feature_contract(features.copy())
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
processor = StateProcessor()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert out["keep"] == features["keep"]
assert_contract_is_typed(out)
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
proc = StateProcessor()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
}
out = proc.feature_contract(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert_contract_is_typed(out)
+214
View File
@@ -16,6 +16,7 @@
import json
import tempfile
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@@ -25,8 +26,10 @@ import pytest
import torch
import torch.nn as nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
@@ -88,6 +91,10 @@ class MockStep:
def reset(self) -> None:
self.counter = 0
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
@dataclass
class MockStepWithoutOptionalMethods:
@@ -106,6 +113,10 @@ class MockStepWithoutOptionalMethods:
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
@dataclass
class MockStepWithTensorState:
@@ -158,6 +169,10 @@ class MockStepWithTensorState:
self.running_mean.zero_()
self.running_count.zero_()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
def test_empty_pipeline():
"""Test pipeline with no steps."""
@@ -699,6 +714,10 @@ class MockModuleStep(nn.Module):
self.running_mean.zero_()
self.counter = 0
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
def test_to_device_with_state_dict():
"""Test moving pipeline to device for steps with state_dict."""
@@ -953,6 +972,10 @@ class MockNonModuleStepWithState:
self.step_count.zero_()
self.history.clear()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
def test_to_device_non_module_class():
"""Test moving pipeline to device for regular classes (non nn.Module) with tensor state.
@@ -1127,6 +1150,10 @@ class MockStepWithNonSerializableParam:
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
@ProcessorStepRegistry.register("registered_mock_step")
@dataclass
@@ -1162,6 +1189,10 @@ class RegisteredMockStep:
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
class MockEnvironment:
"""Mock environment for testing non-serializable parameters."""
@@ -1483,6 +1514,10 @@ class MockStepWithMixedState:
"list_value": self.list_value,
}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
def test_to_device_with_mixed_state_types():
"""Test that to() only moves tensor state, while non-tensor state remains in config."""
@@ -1790,6 +1825,10 @@ def test_state_file_naming_with_registry():
def load_state_dict(self, state):
self.state_tensor = state["state_tensor"]
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
try:
# Create pipeline with registered steps
step1 = TestStatefulStep(1)
@@ -1843,6 +1882,10 @@ def test_override_with_nested_config():
def get_config(self):
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
try:
step = ComplexConfigStep()
pipeline = RobotProcessor([step])
@@ -1931,6 +1974,10 @@ def test_override_with_callables():
def get_config(self):
return {"name": self.name}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
try:
step = CallableStep()
pipeline = RobotProcessor([step])
@@ -2059,6 +2106,10 @@ def test_override_with_device_strings():
def load_state_dict(self, state):
self.buffer = state["buffer"]
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We do not test feature_contract here
return features
try:
step = DeviceAwareStep(device="cpu")
pipeline = RobotProcessor([step])
@@ -2146,3 +2197,166 @@ def test_save_load_with_custom_converter_functions():
# Should work with standard format (wouldn't work with custom converter)
result = loaded(batch)
assert "observation.image" in result # Standard format preserved
class NonCompliantStep:
"""Intentionally non-compliant: missing feature_contract."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def test_construction_rejects_step_without_feature_contract():
with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"):
RobotProcessor([NonCompliantStep()])
class NonCallableStep:
"""Intentionally non-compliant: missing __call__."""
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def test_construction_rejects_step_without_call():
with pytest.raises(TypeError, match=r"must define __call__"):
RobotProcessor([NonCallableStep()])
@dataclass
class FeatureContractAddStep:
"""Adds a PolicyFeature"""
key: str = "a"
value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,))
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[self.key] = self.value
return features
@dataclass
class FeatureContractMutateStep:
"""Mutates a PolicyFeature"""
key: str = "a"
fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[self.key] = self.fn(features.get(self.key))
return features
@dataclass
class FeatureContractBadReturnStep:
"""Returns a non-dict"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return ["not-a-dict"]
@dataclass
class FeatureContractRemoveStep:
"""Removes a PolicyFeature"""
key: str
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(self.key, None)
return features
def test_feature_contract_orders_and_merges(policy_feature_factory):
p = RobotProcessor(
[
FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))),
FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
]
)
out = p.feature_contract({})
assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
assert_contract_is_typed(out)
def test_feature_contract_respects_initial_without_mutation(policy_feature_factory):
initial = {
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
}
p = RobotProcessor(
[
FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))),
FeatureContractMutateStep(
"nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,))
),
]
)
out = p.feature_contract(initial_features=initial)
assert out["seed"].shape == (8,)
assert out["nested"].shape == (5,)
# Initial dict must be preserved
assert initial["seed"].shape == (7,)
assert initial["nested"].shape == (0,)
assert_contract_is_typed(out)
def test_feature_contract_type_error_on_bad_step():
p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()])
with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"):
_ = p.feature_contract({})
def test_feature_contract_execution_order_tracking():
class Track:
def __init__(self, label):
self.label = label
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
code = {"A": 1, "B": 2, "C": 3}[self.label]
pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
return features
out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({})
assert out["order"].shape == (1, 2, 3)
def test_feature_contract_remove_key(policy_feature_factory):
p = RobotProcessor(
[
FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
FeatureContractRemoveStep("a"),
]
)
out = p.feature_contract({})
assert "a" not in out
def test_feature_contract_remove_from_initial(policy_feature_factory):
initial = {
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
}
p = RobotProcessor([FeatureContractRemoveStep("drop")])
out = p.feature_contract(initial_features=initial)
assert "drop" not in out and out["keep"] == initial["keep"]
+59
View File
@@ -19,7 +19,9 @@ from pathlib import Path
import numpy as np
import torch
from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition(
@@ -406,3 +408,60 @@ def test_value_types_preserved():
assert processed_obs["old_string"] == "hello"
assert processed_obs["old_dict"] == {"nested": "value"}
assert processed_obs["old_list"] == [1, 2, 3]
def test_feature_contract_basic_renaming(policy_feature_factory):
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (2,)),
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
"c": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = processor.feature_contract(features.copy())
# Values preserved and typed
assert out["x"] == features["a"]
assert out["y"] == features["b"]
assert out["c"] == features["c"]
assert_contract_is_typed(out)
# Input not mutated
assert set(features) == {"a", "b", "c"}
def test_feature_contract_overlapping_keys(policy_feature_factory):
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
features = {
"a": policy_feature_factory(FeatureType.STATE, (1,)),
"b": policy_feature_factory(FeatureType.STATE, (2,)),
}
out = processor.feature_contract(features)
assert set(out) == {"b", "c"}
assert out["b"] == features["a"] # 'a' renamed to'b'
assert out["c"] == features["b"] # 'b' renamed to 'c'
assert_contract_is_typed(out)
def test_feature_contract_chained_processors(policy_feature_factory):
# Chain two rename processors at the contract level
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameProcessor(
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
)
pipeline = RobotProcessor([processor1, processor2])
spec = {
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
}
out = pipeline.feature_contract(initial_features=spec)
assert set(out) == {"observation.state", "observation.image", "extra"}
assert out["observation.state"] == spec["pos"]
assert out["observation.image"] == spec["img"]
assert out["extra"] == spec["extra"]
assert_contract_is_typed(out)