Feat/pipeline add feature contract (#1637)

* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* Incorporate PR feedback

* type in doc

* oops
This commit is contained in:
Pepijn
2025-07-31 16:29:48 +02:00
committed by Adil Zouitine
parent 5ced72e6b8
commit 2c4e888c7f
9 changed files with 472 additions and 0 deletions
@@ -18,6 +18,7 @@ from typing import Any
import torch import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey from lerobot.processor.pipeline import EnvTransition, TransitionKey
@@ -74,3 +75,6 @@ class DeviceProcessor:
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.""" """Return configuration for serialization."""
return {"device": self.device} return {"device": self.device}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Declare that this step leaves the feature schema unchanged.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.

    Returns:
        The very same mapping object that was passed in (no copy is made).
    """
    return features
@@ -204,6 +204,9 @@ class NormalizerProcessor:
def reset(self): def reset(self):
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Declare that this step leaves the feature schema unchanged.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.

    Returns:
        The very same mapping object that was passed in (no copy is made).
    """
    return features
@dataclass @dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor") @ProcessorStepRegistry.register(name="unnormalizer_processor")
@@ -327,3 +330,6 @@ class UnnormalizerProcessor:
def reset(self): def reset(self):
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Declare that this step leaves the feature schema unchanged.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.

    Returns:
        The very same mapping object that was passed in (no copy is made).
    """
    return features
@@ -21,6 +21,7 @@ import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@@ -110,6 +111,27 @@ class ImageProcessor:
"""Reset processor state (no-op for this processor).""" """Reset processor state (no-op for this processor)."""
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Rename raw pixel feature keys to the canonical image keys.

    Transforms:
        pixels                   -> OBS_IMAGE
        observation.pixels       -> OBS_IMAGE
        pixels.<cam>             -> OBS_IMAGES.<cam>
        observation.pixels.<cam> -> OBS_IMAGES.<cam>

    Every other key is passed through untouched. When both ``pixels`` and
    ``observation.pixels`` are present, the latter wins because it is
    processed last.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.
            It is NOT modified; the renaming is done on a shallow copy.

    Returns:
        A new dict holding the renamed schema.
    """
    # Work on a shallow copy so the caller's dict keeps its original keys.
    renamed = dict(features)

    for whole_key in ("pixels", "observation.pixels"):
        if whole_key in renamed:
            renamed[OBS_IMAGE] = renamed.pop(whole_key)

    prefixes = ("pixels.", "observation.pixels.")
    # Iterate over a snapshot of the keys: the dict is edited inside the loop.
    for key in list(renamed):
        for prefix in prefixes:
            if key.startswith(prefix):
                camera = key[len(prefix) :]
                renamed[f"{OBS_IMAGES}.{camera}"] = renamed.pop(key)
                break
    return renamed
@dataclass @dataclass
class StateProcessor: class StateProcessor:
@@ -169,6 +191,25 @@ class StateProcessor:
"""Reset processor state (no-op for this processor).""" """Reset processor state (no-op for this processor)."""
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Rename raw state feature keys to the canonical observation keys.

    Transforms:
        environment_state             -> OBS_ENV_STATE
        agent_pos                     -> OBS_STATE
        observation.environment_state -> OBS_ENV_STATE
        observation.agent_pos         -> OBS_STATE

    Every other key is passed through untouched. When both the bare and the
    ``observation.``-prefixed variant of a key are present, the prefixed one
    wins because it is processed last.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.
            It is NOT modified; the renaming is done on a shallow copy.

    Returns:
        A new dict holding the renamed schema.
    """
    # Work on a shallow copy so the caller's dict keeps its original keys.
    renamed = dict(features)
    pairs = (
        ("environment_state", OBS_ENV_STATE),
        ("agent_pos", OBS_STATE),
    )
    for old, new in pairs:
        for source_key in (old, f"observation.{old}"):
            if source_key in renamed:
                renamed[new] = renamed.pop(source_key)
    return renamed
@dataclass @dataclass
@ProcessorStepRegistry.register(name="observation_processor") @ProcessorStepRegistry.register(name="observation_processor")
@@ -219,3 +260,8 @@ class VanillaObservationProcessor:
"""Reset processor state.""" """Reset processor state."""
self.image_processor.reset() self.image_processor.reset()
self.state_processor.reset() self.state_processor.reset()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Chain the contracts of the image and state sub-processors.

    The schema is first handed to ``self.image_processor`` and the result of
    that call is then handed to ``self.state_processor``.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.

    Returns:
        Whatever the state processor's ``feature_contract`` returns.
    """
    after_images = self.image_processor.feature_contract(features)
    return self.state_processor.feature_contract(after_images)
+39
View File
@@ -19,6 +19,7 @@ import importlib
import json import json
import os import os
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@@ -29,6 +30,7 @@ from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from lerobot.configs.types import PolicyFeature
from lerobot.utils.utils import get_safe_torch_device from lerobot.utils.utils import get_safe_torch_device
@@ -141,6 +143,11 @@ class ProcessorStep(Protocol):
automatically serialise the step's configuration and learnable state using automatically serialise the step's configuration and learnable state using
a safe-to-share JSON + SafeTensors format. a safe-to-share JSON + SafeTensors format.
**Required**:
- ``__call__(transition: EnvTransition) -> EnvTransition``
- ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``
Optional helper protocol: Optional helper protocol:
* ``get_config() -> dict[str, Any]`` User-defined JSON-serializable * ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all configuration and state. YOU decide what to save here. This is where all
@@ -168,6 +175,8 @@ class ProcessorStep(Protocol):
def reset(self) -> None: ... def reset(self) -> None: ...
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
"""Convert a *batch* dict coming from Learobot replay/dataset code into an """Convert a *batch* dict coming from Learobot replay/dataset code into an
@@ -840,6 +849,33 @@ class RobotProcessor(ModelHubMixin):
return f"RobotProcessor({', '.join(parts)})" return f"RobotProcessor({', '.join(parts)})"
def __post_init__(self):
    """Validate that every configured step satisfies the required protocol.

    Raises:
        TypeError: If a step in ``self.steps`` is not callable, or if it has
            no callable ``feature_contract`` attribute. The message names the
            offending step's position and class.
    """
    for index, candidate in enumerate(self.steps):
        if not callable(candidate):
            raise TypeError(
                f"Step {index} ({type(candidate).__name__}) must define __call__(transition) -> EnvTransition"
            )

        contract = getattr(candidate, "feature_contract", None)
        if callable(contract):
            continue
        raise TypeError(
            f"Step {index} ({type(candidate).__name__}) must define feature_contract(features) -> dict[str, Any]"
        )
def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Propagate a feature schema through every step, in order.

    Each step's ``feature_contract`` receives the dict returned by the
    previous step. Its return value REPLACES the running schema wholesale
    (nothing is merged), so every step must return the complete schema, not
    only the keys it changed.

    Args:
        initial_features: Schema fed to the first step. It is deep-copied
            up front, so the caller's dict and its values are never modified.

    Returns:
        The schema returned by the last step, or the deep copy of
        ``initial_features`` when the pipeline has no steps.

    Raises:
        TypeError: If any step returns something that is not a ``dict``.
    """
    features: dict[str, PolicyFeature] = deepcopy(initial_features)
    for step in self.steps:
        out = step.feature_contract(features)
        if not isinstance(out, dict):
            raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
        features = out
    return features
class ObservationProcessor: class ObservationProcessor:
"""Base class for processors that modify only the observation component of a transition. """Base class for processors that modify only the observation component of a transition.
@@ -1145,3 +1181,6 @@ class IdentityProcessor:
def reset(self) -> None: def reset(self) -> None:
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Declare that this step leaves the feature schema unchanged.

    Args:
        features: Mapping of feature name to its ``PolicyFeature`` description.

    Returns:
        The very same mapping object that was passed in (no copy is made).
    """
    return features
@@ -18,6 +18,7 @@ from typing import Any
import torch import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@@ -53,3 +54,10 @@ class RenameProcessor:
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Apply ``self.rename_map`` to the feature names.

    Transforms:
        - Each key that appears in ``rename_map`` is renamed to its mapped value.
        - Keys not in ``rename_map`` remain unchanged.

    The input mapping is left untouched; a new dict is built and returned.
    If two input keys end up with the same output name, the one visited
    later overwrites the earlier one.
    """
    renamed: dict[str, PolicyFeature] = {}
    for original_name, feature in features.items():
        target_name = self.rename_map.get(original_name, original_name)
        renamed[target_name] = feature
    return renamed
+17
View File
@@ -19,6 +19,7 @@ import traceback
import pytest import pytest
from serial import SerialException from serial import SerialException
from lerobot.configs.types import FeatureType, PolicyFeature
from tests.utils import DEVICE from tests.utils import DEVICE
# Import fixture modules as plugins # Import fixture modules as plugins
@@ -69,3 +70,19 @@ def patch_builtins_input(monkeypatch):
print(text) print(text)
monkeypatch.setattr("builtins.input", print_text) monkeypatch.setattr("builtins.input", print_text)
@pytest.fixture
def policy_feature_factory():
    """Fixture returning a helper that builds a ``PolicyFeature`` from a type and a shape."""

    def _make_feature(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
        feature = PolicyFeature(type=ft, shape=shape)
        return feature

    return _make_feature
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
    """Assert that ``features`` is a dict mapping ``str`` names to ``PolicyFeature`` values.

    Raises:
        AssertionError: If ``features`` is not a dict, any key is not a
            string, or any value is not a ``PolicyFeature``.
    """
    assert isinstance(features, dict)

    keys_are_strings = all(isinstance(name, str) for name in features)
    assert keys_are_strings

    values_are_features = all(isinstance(spec, PolicyFeature) for spec in features.values())
    assert values_are_features
@@ -18,12 +18,15 @@ import numpy as np
import pytest import pytest
import torch import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import ( from lerobot.processor import (
ImageProcessor, ImageProcessor,
StateProcessor, StateProcessor,
VanillaObservationProcessor, VanillaObservationProcessor,
) )
from lerobot.processor.pipeline import TransitionKey from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition( def create_transition(
@@ -420,3 +423,79 @@ def test_equivalent_with_image_dict():
for key in original_result: for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key]) torch.testing.assert_close(original_result[key], processor_result[key])
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
    """A bare ``pixels`` feature is renamed to ``OBS_IMAGE``; unrelated keys are kept."""
    spec = {
        "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "keep": policy_feature_factory(FeatureType.ENV, (1,)),
    }

    result = ImageProcessor().feature_contract(spec.copy())

    assert OBS_IMAGE in result
    assert result[OBS_IMAGE] == spec["pixels"]
    assert "pixels" not in result
    assert result["keep"] == spec["keep"]
    assert_contract_is_typed(result)
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
    """An ``observation.pixels`` feature is renamed to ``OBS_IMAGE``; unrelated keys are kept."""
    spec = {
        "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "keep": policy_feature_factory(FeatureType.ENV, (1,)),
    }

    result = ImageProcessor().feature_contract(spec.copy())

    assert OBS_IMAGE in result
    assert result[OBS_IMAGE] == spec["observation.pixels"]
    assert "observation.pixels" not in result
    assert result["keep"] == spec["keep"]
    assert_contract_is_typed(result)
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
    """Per-camera ``pixels.<cam>`` / ``observation.pixels.<cam>`` keys move under ``OBS_IMAGES``."""
    spec = {
        "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "keep": policy_feature_factory(FeatureType.ENV, (7,)),
    }

    result = ImageProcessor().feature_contract(spec.copy())

    # Checked in the same order as the source keys: front, wrist, rear.
    for camera, source_key in (
        ("front", "pixels.front"),
        ("wrist", "pixels.wrist"),
        ("rear", "observation.pixels.rear"),
    ):
        target_key = f"{OBS_IMAGES}.{camera}"
        assert target_key in result
        assert result[target_key] == spec[source_key]

    assert "pixels.front" not in result
    assert "pixels.wrist" not in result
    assert "observation.pixels.rear" not in result
    assert result["keep"] == spec["keep"]
    assert_contract_is_typed(result)
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
    """Bare ``environment_state`` / ``agent_pos`` keys map to ``OBS_ENV_STATE`` / ``OBS_STATE``."""
    spec = {
        "environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
        "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
        "keep": policy_feature_factory(FeatureType.ENV, (1,)),
    }

    result = StateProcessor().feature_contract(spec.copy())

    assert OBS_ENV_STATE in result
    assert result[OBS_ENV_STATE] == spec["environment_state"]
    assert OBS_STATE in result
    assert result[OBS_STATE] == spec["agent_pos"]
    assert "environment_state" not in result
    assert "agent_pos" not in result
    assert result["keep"] == spec["keep"]
    assert_contract_is_typed(result)
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
    """``observation.``-prefixed state keys map to ``OBS_ENV_STATE`` / ``OBS_STATE``."""
    spec = {
        "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
        "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
    }

    result = StateProcessor().feature_contract(spec.copy())

    assert OBS_ENV_STATE in result
    assert result[OBS_ENV_STATE] == spec["observation.environment_state"]
    assert OBS_STATE in result
    assert result[OBS_STATE] == spec["observation.agent_pos"]
    assert "environment_state" not in result
    assert "agent_pos" not in result
    assert_contract_is_typed(result)
+214
View File
@@ -16,6 +16,7 @@
import json import json
import tempfile import tempfile
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -25,8 +26,10 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.pipeline import TransitionKey from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition( def create_transition(
@@ -88,6 +91,10 @@ class MockStep:
def reset(self) -> None: def reset(self) -> None:
self.counter = 0 self.counter = 0
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
@dataclass @dataclass
class MockStepWithoutOptionalMethods: class MockStepWithoutOptionalMethods:
@@ -106,6 +113,10 @@ class MockStepWithoutOptionalMethods:
return transition return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
@dataclass @dataclass
class MockStepWithTensorState: class MockStepWithTensorState:
@@ -158,6 +169,10 @@ class MockStepWithTensorState:
self.running_mean.zero_() self.running_mean.zero_()
self.running_count.zero_() self.running_count.zero_()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
def test_empty_pipeline(): def test_empty_pipeline():
"""Test pipeline with no steps.""" """Test pipeline with no steps."""
@@ -699,6 +714,10 @@ class MockModuleStep(nn.Module):
self.running_mean.zero_() self.running_mean.zero_()
self.counter = 0 self.counter = 0
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
def test_to_device_with_state_dict(): def test_to_device_with_state_dict():
"""Test moving pipeline to device for steps with state_dict.""" """Test moving pipeline to device for steps with state_dict."""
@@ -953,6 +972,10 @@ class MockNonModuleStepWithState:
self.step_count.zero_() self.step_count.zero_()
self.history.clear() self.history.clear()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
def test_to_device_non_module_class(): def test_to_device_non_module_class():
"""Test moving pipeline to device for regular classes (non nn.Module) with tensor state. """Test moving pipeline to device for regular classes (non nn.Module) with tensor state.
@@ -1127,6 +1150,10 @@ class MockStepWithNonSerializableParam:
def reset(self) -> None: def reset(self) -> None:
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
@ProcessorStepRegistry.register("registered_mock_step") @ProcessorStepRegistry.register("registered_mock_step")
@dataclass @dataclass
@@ -1162,6 +1189,10 @@ class RegisteredMockStep:
def reset(self) -> None: def reset(self) -> None:
pass pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
class MockEnvironment: class MockEnvironment:
"""Mock environment for testing non-serializable parameters.""" """Mock environment for testing non-serializable parameters."""
@@ -1483,6 +1514,10 @@ class MockStepWithMixedState:
"list_value": self.list_value, "list_value": self.list_value,
} }
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
def test_to_device_with_mixed_state_types(): def test_to_device_with_mixed_state_types():
"""Test that to() only moves tensor state, while non-tensor state remains in config.""" """Test that to() only moves tensor state, while non-tensor state remains in config."""
@@ -1790,6 +1825,10 @@ def test_state_file_naming_with_registry():
def load_state_dict(self, state): def load_state_dict(self, state):
self.state_tensor = state["state_tensor"] self.state_tensor = state["state_tensor"]
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
try: try:
# Create pipeline with registered steps # Create pipeline with registered steps
step1 = TestStatefulStep(1) step1 = TestStatefulStep(1)
@@ -1843,6 +1882,10 @@ def test_override_with_nested_config():
def get_config(self): def get_config(self):
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
try: try:
step = ComplexConfigStep() step = ComplexConfigStep()
pipeline = RobotProcessor([step]) pipeline = RobotProcessor([step])
@@ -1931,6 +1974,10 @@ def test_override_with_callables():
def get_config(self): def get_config(self):
return {"name": self.name} return {"name": self.name}
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
try: try:
step = CallableStep() step = CallableStep()
pipeline = RobotProcessor([step]) pipeline = RobotProcessor([step])
@@ -2059,6 +2106,10 @@ def test_override_with_device_strings():
def load_state_dict(self, state): def load_state_dict(self, state):
self.buffer = state["buffer"] self.buffer = state["buffer"]
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
    """Identity contract: this mock is not used to test ``feature_contract``."""
    return features
try: try:
step = DeviceAwareStep(device="cpu") step = DeviceAwareStep(device="cpu")
pipeline = RobotProcessor([step]) pipeline = RobotProcessor([step])
@@ -2146,3 +2197,166 @@ def test_save_load_with_custom_converter_functions():
# Should work with standard format (wouldn't work with custom converter) # Should work with standard format (wouldn't work with custom converter)
result = loaded(batch) result = loaded(batch)
assert "observation.image" in result # Standard format preserved assert "observation.image" in result # Standard format preserved
class NonCompliantStep:
    """Intentionally non-compliant: missing feature_contract.

    The step is callable (it returns the transition it is given) but it
    deliberately defines no ``feature_contract`` method.
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        unchanged = transition
        return unchanged
def test_construction_rejects_step_without_feature_contract():
    """Building a pipeline from a step without ``feature_contract`` raises ``TypeError``."""
    expected_message = r"must define feature_contract\(features\) -> dict\[str, Any\]"
    with pytest.raises(TypeError, match=expected_message):
        RobotProcessor([NonCompliantStep()])
class NonCallableStep:
    """Intentionally non-compliant: missing __call__.

    The step provides an identity ``feature_contract`` but deliberately
    defines no ``__call__``, so instances are not callable.
    """

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        unchanged = features
        return unchanged
def test_construction_rejects_step_without_call():
    """Building a pipeline from a non-callable step raises ``TypeError``."""
    expected_message = r"must define __call__"
    with pytest.raises(TypeError, match=expected_message):
        RobotProcessor([NonCallableStep()])
@dataclass
class FeatureContractAddStep:
    """Adds a PolicyFeature.

    ``feature_contract`` stores ``value`` under ``key`` in the dict it is
    given (editing that dict in place) and returns the same dict.
    """

    key: str = "a"
    # ``None`` means "use a fresh PolicyFeature(STATE, (1,))". A concrete
    # instance is deliberately not used as the field default: it would be one
    # object shared by every step, and dataclasses on Python 3.11+ reject
    # unhashable defaults outright (NOTE(review): PolicyFeature is presumably an
    # unhashable plain dataclass - confirm against lerobot.configs.types).
    value: PolicyFeature | None = None

    def __post_init__(self) -> None:
        if self.value is None:
            self.value = PolicyFeature(type=FeatureType.STATE, shape=(1,))

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        features[self.key] = self.value
        return features
@dataclass
class FeatureContractMutateStep:
    """Mutates a PolicyFeature.

    ``feature_contract`` replaces the entry stored under ``key`` with
    ``fn(current)``, where ``current`` is the existing entry or ``None`` when
    the key is absent. The dict is edited in place and returned.
    """

    key: str = "a"
    fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x  # noqa: E731

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        current = features.get(self.key)
        features[self.key] = self.fn(current)
        return features
@dataclass
class FeatureContractBadReturnStep:
    """Returns a non-dict.

    ``feature_contract`` deliberately violates its annotated return type by
    handing back a list.
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        not_a_schema = ["not-a-dict"]
        return not_a_schema
@dataclass
class FeatureContractRemoveStep:
    """Removes a PolicyFeature.

    ``feature_contract`` deletes ``key`` from the dict it is given (a missing
    key is ignored) and returns that same dict.
    """

    key: str

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        if self.key in features:
            del features[self.key]
        return features
def test_feature_contract_orders_and_merges(policy_feature_factory):
    """Steps run in order: a later step sees and can rewrite what an earlier one added."""
    steps = [
        FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
        FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))),
        FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
    ]
    pipeline = RobotProcessor(steps)

    result = pipeline.feature_contract({})

    assert result["a"].type == FeatureType.STATE
    assert result["a"].shape == (3,)
    assert result["b"].type == FeatureType.ENV
    assert result["b"].shape == (2,)
    assert_contract_is_typed(result)
def test_feature_contract_respects_initial_without_mutation(policy_feature_factory):
    """The pipeline applies its steps to a copy, leaving the caller's initial schema intact."""
    initial = {
        "seed": policy_feature_factory(FeatureType.STATE, (7,)),
        "nested": policy_feature_factory(FeatureType.ENV, (0,)),
    }

    def grow_by(amount):
        # Build a mutator that returns a new feature whose first dim is enlarged by `amount`.
        return lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + amount,))

    pipeline = RobotProcessor(
        [
            FeatureContractMutateStep("seed", grow_by(1)),
            FeatureContractMutateStep("nested", grow_by(5)),
        ]
    )

    result = pipeline.feature_contract(initial_features=initial)

    assert result["seed"].shape == (8,)
    assert result["nested"].shape == (5,)
    # Initial dict must be preserved
    assert initial["seed"].shape == (7,)
    assert initial["nested"].shape == (0,)

    assert_contract_is_typed(result)
def test_feature_contract_type_error_on_bad_step():
    """A step whose ``feature_contract`` returns a non-dict makes the pipeline raise ``TypeError``."""
    pipeline = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()])
    expected_message = r"\w+\.feature_contract must return dict\[str, Any\]"
    with pytest.raises(TypeError, match=expected_message):
        pipeline.feature_contract({})
def test_feature_contract_execution_order_tracking():
    """Each step appends its code to the ``order`` feature's shape, exposing the call order."""

    class Track:
        def __init__(self, label):
            self.label = label

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            return transition

        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
            code = {"A": 1, "B": 2, "C": 3}[self.label]
            previous = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
            extended_shape = previous.shape + (code,)
            features["order"] = PolicyFeature(type=previous.type, shape=extended_shape)
            return features

    pipeline = RobotProcessor([Track("A"), Track("B"), Track("C")])
    result = pipeline.feature_contract({})

    assert result["order"].shape == (1, 2, 3)
def test_feature_contract_remove_key(policy_feature_factory):
    """A key added by one step can be removed by a later step."""
    add_a = FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,)))
    remove_a = FeatureContractRemoveStep("a")
    pipeline = RobotProcessor([add_a, remove_a])

    result = pipeline.feature_contract({})

    assert "a" not in result
def test_feature_contract_remove_from_initial(policy_feature_factory):
    """A step can drop a key that came from the initial schema while other keys survive."""
    initial = {
        "keep": policy_feature_factory(FeatureType.STATE, (1,)),
        "drop": policy_feature_factory(FeatureType.STATE, (1,)),
    }
    pipeline = RobotProcessor([FeatureContractRemoveStep("drop")])

    result = pipeline.feature_contract(initial_features=initial)

    assert "drop" not in result
    assert result["keep"] == initial["keep"]
+59
View File
@@ -19,7 +19,9 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from tests.conftest import assert_contract_is_typed
def create_transition( def create_transition(
@@ -406,3 +408,60 @@ def test_value_types_preserved():
assert processed_obs["old_string"] == "hello" assert processed_obs["old_string"] == "hello"
assert processed_obs["old_dict"] == {"nested": "value"} assert processed_obs["old_dict"] == {"nested": "value"}
assert processed_obs["old_list"] == [1, 2, 3] assert processed_obs["old_list"] == [1, 2, 3]
def test_feature_contract_basic_renaming(policy_feature_factory):
    """Mapped keys are renamed, unmapped keys pass through, and the input dict is untouched."""
    processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
    features = {
        "a": policy_feature_factory(FeatureType.STATE, (2,)),
        "b": policy_feature_factory(FeatureType.ACTION, (3,)),
        "c": policy_feature_factory(FeatureType.ENV, (1,)),
    }

    # Pass the dict itself (not a copy): otherwise the "input not mutated"
    # assertion below could never fail.
    out = processor.feature_contract(features)

    # Values preserved and typed
    assert out["x"] == features["a"]
    assert out["y"] == features["b"]
    assert out["c"] == features["c"]
    assert_contract_is_typed(out)

    # Input not mutated
    assert out is not features
    assert set(features) == {"a", "b", "c"}
def test_feature_contract_overlapping_keys(policy_feature_factory):
    """Overlapping renames ('a'->'b', 'b'->'c') are applied per key, not chained."""
    spec = {
        "a": policy_feature_factory(FeatureType.STATE, (1,)),
        "b": policy_feature_factory(FeatureType.STATE, (2,)),
    }
    renamer = RenameProcessor(rename_map={"a": "b", "b": "c"})

    result = renamer.feature_contract(spec)

    assert set(result) == {"b", "c"}
    assert result["b"] == spec["a"]  # 'a' renamed to 'b'
    assert result["c"] == spec["b"]  # 'b' renamed to 'c'
    assert_contract_is_typed(result)
def test_feature_contract_chained_processors(policy_feature_factory):
    """Two rename processors chained in a pipeline compose their renames at the contract level."""
    first_rename = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
    second_rename = RenameProcessor(
        rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
    )
    pipeline = RobotProcessor([first_rename, second_rename])

    spec = {
        "pos": policy_feature_factory(FeatureType.STATE, (7,)),
        "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "extra": policy_feature_factory(FeatureType.ENV, (1,)),
    }
    result = pipeline.feature_contract(initial_features=spec)

    assert set(result) == {"observation.state", "observation.image", "extra"}
    for final_key, source_key in (
        ("observation.state", "pos"),
        ("observation.image", "img"),
        ("extra", "extra"),
    ):
        assert result[final_key] == spec[source_key]
    assert_contract_is_typed(result)