refactor(pipeline): feature contract now categorizes features as observation or action (#1867)

* refactor(processor): signature of transform_features

* refactor(processor): remove prefixes + processors respect the new transform_features signature + update tests accordingly

* refactor(processor): rename now is only for visual

* refactor(processor): update normalize processor

* refactor(processor): update vanilla processor features

* refactor(processor): feature contract now uses its own enum

* chore(processor): rename renameprocessor

* chore(processor): minor changes

* refactor(processor): add create & change aggregate

* refactor(processor): update aggregate

* refactor(processor): simplify to functions, fix features contracts and rename function

* test(processor): remove the "to" converter tests, as the converters are now very simple

* chore(docs): recover docs joint observations processor

* fix(processor): update RKP

* fix(tests): recv diff test_pipeline

* chore(tests): add docs to test

* chore(processor): leave obs language constant untouched

* fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
Steven Palma
2025-09-09 18:27:30 +02:00
committed by GitHub
parent acf0ba7fb3
commit e881fb6678
47 changed files with 781 additions and 616 deletions
+116 -58
View File
@@ -25,7 +25,7 @@ import pytest
import torch
import torch.nn as nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.processor import (
DataProcessorPipeline,
@@ -96,7 +96,9 @@ class MockStep(ProcessorStep):
def reset(self) -> None:
self.counter = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -118,7 +120,9 @@ class MockStepWithoutOptionalMethods(ProcessorStep):
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -174,7 +178,9 @@ class MockStepWithTensorState(ProcessorStep):
self.running_mean.zero_()
self.running_count.zero_()
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -670,7 +676,9 @@ class MockModuleStep(ProcessorStep, nn.Module):
self.running_mean.zero_()
self.counter = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -752,7 +760,9 @@ class MockNonModuleStepWithState(ProcessorStep):
self.step_count.zero_()
self.history.clear()
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -807,7 +817,9 @@ class MockStepWithNonSerializableParam(ProcessorStep):
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -846,7 +858,9 @@ class RegisteredMockStep(ProcessorStep):
def reset(self) -> None:
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -1406,7 +1420,9 @@ def test_state_file_naming_with_registry():
def load_state_dict(self, state):
self.state_tensor = state["state_tensor"]
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -1463,7 +1479,9 @@ def test_override_with_nested_config():
def get_config(self):
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -1557,7 +1575,9 @@ def test_override_with_callables():
def get_config(self):
return {"name": self.name}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -1692,7 +1712,9 @@ def test_override_with_device_strings():
def load_state_dict(self, state):
self.buffer = state["buffer"]
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# We do not test features here
return features
@@ -1805,16 +1827,20 @@ class NonCompliantStep:
return transition
class NonCallableStep:
class NonCallableStep(ProcessorStep):
"""Intentionally non-compliant: missing __call__."""
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def test_construction_rejects_step_without_processorstep():
def test_construction_rejects_step_without_call():
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
with pytest.raises(
TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_"
):
DataProcessorPipeline([NonCallableStep()])
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
@@ -1831,8 +1857,10 @@ class FeatureContractAddStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[self.key] = self.value
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.OBSERVATION][self.key] = self.value
return features
@@ -1846,8 +1874,12 @@ class FeatureContractMutateStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[self.key] = self.fn(features.get(self.key))
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.OBSERVATION][self.key] = self.fn(
features[PipelineFeatureType.OBSERVATION].get(self.key)
)
return features
@@ -1858,7 +1890,9 @@ class FeatureContractBadReturnStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return ["not-a-dict"]
@@ -1871,8 +1905,10 @@ class FeatureContractRemoveStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(self.key, None)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.OBSERVATION].pop(self.key, None)
return features
@@ -1884,17 +1920,22 @@ def test_features_orders_and_merges(policy_feature_factory):
FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
]
)
out = p.transform_features({})
assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
out = p.transform_features({PipelineFeatureType.OBSERVATION: {}})
assert out[PipelineFeatureType.OBSERVATION]["a"].type == FeatureType.STATE and out[
PipelineFeatureType.OBSERVATION
]["a"].shape == (3,)
assert out[PipelineFeatureType.OBSERVATION]["b"].type == FeatureType.ENV and out[
PipelineFeatureType.OBSERVATION
]["b"].shape == (2,)
assert_contract_is_typed(out)
def test_features_respects_initial_without_mutation(policy_feature_factory):
initial = {
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
PipelineFeatureType.OBSERVATION: {
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
}
}
p = DataProcessorPipeline(
[
@@ -1906,11 +1947,11 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
)
out = p.transform_features(initial_features=initial)
assert out["seed"].shape == (8,)
assert out["nested"].shape == (5,)
assert out[PipelineFeatureType.OBSERVATION]["seed"].shape == (8,)
assert out[PipelineFeatureType.OBSERVATION]["nested"].shape == (5,)
# Initial dict must be preserved
assert initial["seed"].shape == (7,)
assert initial["nested"].shape == (0,)
assert initial[PipelineFeatureType.OBSERVATION]["seed"].shape == (7,)
assert initial[PipelineFeatureType.OBSERVATION]["nested"].shape == (0,)
assert_contract_is_typed(out)
@@ -1923,14 +1964,22 @@ def test_features_execution_order_tracking():
def __call__(self, transition: EnvTransition) -> EnvTransition:
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
code = {"A": 1, "B": 2, "C": 3}[self.label]
pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
pf = features[PipelineFeatureType.OBSERVATION].get(
"order", PolicyFeature(type=FeatureType.ENV, shape=())
)
features[PipelineFeatureType.OBSERVATION]["order"] = PolicyFeature(
type=pf.type, shape=pf.shape + (code,)
)
return features
out = DataProcessorPipeline([Track("A"), Track("B"), Track("C")]).transform_features({})
assert out["order"].shape == (1, 2, 3)
out = DataProcessorPipeline([Track("A"), Track("B"), Track("C")]).transform_features(
initial_features={PipelineFeatureType.OBSERVATION: {}}
)
assert out[PipelineFeatureType.OBSERVATION]["order"].shape == (1, 2, 3)
def test_features_remove_key(policy_feature_factory):
@@ -1940,18 +1989,23 @@ def test_features_remove_key(policy_feature_factory):
FeatureContractRemoveStep("a"),
]
)
out = p.transform_features({})
assert "a" not in out
out = p.transform_features({PipelineFeatureType.OBSERVATION: {}})
assert "a" not in out[PipelineFeatureType.OBSERVATION]
def test_features_remove_from_initial(policy_feature_factory):
initial = {
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
PipelineFeatureType.OBSERVATION: {
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
},
}
p = DataProcessorPipeline([FeatureContractRemoveStep("drop")])
out = p.transform_features(initial_features=initial)
assert "drop" not in out and out["keep"] == initial["keep"]
assert (
"drop" not in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION]["keep"] == initial[PipelineFeatureType.OBSERVATION]["keep"]
)
@dataclass
@@ -1961,13 +2015,15 @@ class AddActionEEAndJointFeatures(ProcessorStep):
def __call__(self, tr):
return tr
def transform_features(self, features: dict) -> dict:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# EE features
features["action.ee.x"] = float
features["action.ee.y"] = float
features[PipelineFeatureType.ACTION]["action.ee.x"] = float
features[PipelineFeatureType.ACTION]["action.ee.y"] = float
# JOINT features
features["action.j1.pos"] = float
features["action.j2.pos"] = float
features[PipelineFeatureType.ACTION]["action.j1.pos"] = float
features[PipelineFeatureType.ACTION]["action.j2.pos"] = float
return features
@@ -1981,18 +2037,20 @@ class AddObservationStateFeatures(ProcessorStep):
def __call__(self, tr):
return tr
def transform_features(self, features: dict) -> dict:
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# State features (mix EE and a joint state)
features["observation.state.ee.x"] = float
features["observation.state.j1.pos"] = float
features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float
features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float
if self.add_front_image:
features["observation.images.front"] = self.front_image_shape
features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape
return features
def test_aggregate_joint_action_only():
rp = DataProcessorPipeline([AddActionEEAndJointFeatures()])
initial = {"front": (480, 640, 3)}
initial = {PipelineFeatureType.OBSERVATION: {"front": (480, 640, 3)}, PipelineFeatureType.ACTION: {}}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
@@ -2014,7 +2072,7 @@ def test_aggregate_ee_action_and_observation_with_videos():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features=initial,
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
use_videos=True,
patterns=["action.ee", "observation.state"],
)
@@ -2042,7 +2100,7 @@ def test_aggregate_both_action_types():
rp = DataProcessorPipeline([AddActionEEAndJointFeatures()])
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={},
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}},
use_videos=True,
patterns=["action.ee", "action.j1", "action.j2.pos"],
)
@@ -2059,7 +2117,7 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features=initial,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # expect "image" dtype
patterns=None,
)
@@ -2076,7 +2134,7 @@ def test_aggregate_images_when_use_videos_true():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features=initial,
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
use_videos=True,
patterns=None,
)
@@ -2100,7 +2158,7 @@ def test_initial_camera_not_overridden_by_step_image():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features=initial,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=True,
patterns=["observation.images.front"],
)