mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)
* refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
@@ -25,7 +25,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
@@ -96,7 +96,9 @@ class MockStep(ProcessorStep):
|
||||
def reset(self) -> None:
|
||||
self.counter = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -118,7 +120,9 @@ class MockStepWithoutOptionalMethods(ProcessorStep):
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -174,7 +178,9 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
self.running_mean.zero_()
|
||||
self.running_count.zero_()
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -670,7 +676,9 @@ class MockModuleStep(ProcessorStep, nn.Module):
|
||||
self.running_mean.zero_()
|
||||
self.counter = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -752,7 +760,9 @@ class MockNonModuleStepWithState(ProcessorStep):
|
||||
self.step_count.zero_()
|
||||
self.history.clear()
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -807,7 +817,9 @@ class MockStepWithNonSerializableParam(ProcessorStep):
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -846,7 +858,9 @@ class RegisteredMockStep(ProcessorStep):
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -1406,7 +1420,9 @@ def test_state_file_naming_with_registry():
|
||||
def load_state_dict(self, state):
|
||||
self.state_tensor = state["state_tensor"]
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -1463,7 +1479,9 @@ def test_override_with_nested_config():
|
||||
def get_config(self):
|
||||
return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -1557,7 +1575,9 @@ def test_override_with_callables():
|
||||
def get_config(self):
|
||||
return {"name": self.name}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -1692,7 +1712,9 @@ def test_override_with_device_strings():
|
||||
def load_state_dict(self, state):
|
||||
self.buffer = state["buffer"]
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We do not test features here
|
||||
return features
|
||||
|
||||
@@ -1805,16 +1827,20 @@ class NonCompliantStep:
|
||||
return transition
|
||||
|
||||
|
||||
class NonCallableStep:
|
||||
class NonCallableStep(ProcessorStep):
|
||||
"""Intentionally non-compliant: missing __call__."""
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def test_construction_rejects_step_without_processorstep():
|
||||
def test_construction_rejects_step_without_call():
|
||||
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||
with pytest.raises(
|
||||
TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_"
|
||||
):
|
||||
DataProcessorPipeline([NonCallableStep()])
|
||||
|
||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||
@@ -1831,8 +1857,10 @@ class FeatureContractAddStep(ProcessorStep):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[self.key] = self.value
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.OBSERVATION][self.key] = self.value
|
||||
return features
|
||||
|
||||
|
||||
@@ -1846,8 +1874,12 @@ class FeatureContractMutateStep(ProcessorStep):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[self.key] = self.fn(features.get(self.key))
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.OBSERVATION][self.key] = self.fn(
|
||||
features[PipelineFeatureType.OBSERVATION].get(self.key)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@@ -1858,7 +1890,9 @@ class FeatureContractBadReturnStep(ProcessorStep):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return ["not-a-dict"]
|
||||
|
||||
|
||||
@@ -1871,8 +1905,10 @@ class FeatureContractRemoveStep(ProcessorStep):
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(self.key, None)
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.OBSERVATION].pop(self.key, None)
|
||||
return features
|
||||
|
||||
|
||||
@@ -1884,17 +1920,22 @@ def test_features_orders_and_merges(policy_feature_factory):
|
||||
FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
|
||||
]
|
||||
)
|
||||
out = p.transform_features({})
|
||||
|
||||
assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
|
||||
assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
|
||||
out = p.transform_features({PipelineFeatureType.OBSERVATION: {}})
|
||||
assert out[PipelineFeatureType.OBSERVATION]["a"].type == FeatureType.STATE and out[
|
||||
PipelineFeatureType.OBSERVATION
|
||||
]["a"].shape == (3,)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["b"].type == FeatureType.ENV and out[
|
||||
PipelineFeatureType.OBSERVATION
|
||||
]["b"].shape == (2,)
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_features_respects_initial_without_mutation(policy_feature_factory):
|
||||
initial = {
|
||||
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"seed": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"nested": policy_feature_factory(FeatureType.ENV, (0,)),
|
||||
}
|
||||
}
|
||||
p = DataProcessorPipeline(
|
||||
[
|
||||
@@ -1906,11 +1947,11 @@ def test_features_respects_initial_without_mutation(policy_feature_factory):
|
||||
)
|
||||
out = p.transform_features(initial_features=initial)
|
||||
|
||||
assert out["seed"].shape == (8,)
|
||||
assert out["nested"].shape == (5,)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["seed"].shape == (8,)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["nested"].shape == (5,)
|
||||
# Initial dict must be preserved
|
||||
assert initial["seed"].shape == (7,)
|
||||
assert initial["nested"].shape == (0,)
|
||||
assert initial[PipelineFeatureType.OBSERVATION]["seed"].shape == (7,)
|
||||
assert initial[PipelineFeatureType.OBSERVATION]["nested"].shape == (0,)
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
@@ -1923,14 +1964,22 @@ def test_features_execution_order_tracking():
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
code = {"A": 1, "B": 2, "C": 3}[self.label]
|
||||
pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
|
||||
features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
|
||||
pf = features[PipelineFeatureType.OBSERVATION].get(
|
||||
"order", PolicyFeature(type=FeatureType.ENV, shape=())
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]["order"] = PolicyFeature(
|
||||
type=pf.type, shape=pf.shape + (code,)
|
||||
)
|
||||
return features
|
||||
|
||||
out = DataProcessorPipeline([Track("A"), Track("B"), Track("C")]).transform_features({})
|
||||
assert out["order"].shape == (1, 2, 3)
|
||||
out = DataProcessorPipeline([Track("A"), Track("B"), Track("C")]).transform_features(
|
||||
initial_features={PipelineFeatureType.OBSERVATION: {}}
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["order"].shape == (1, 2, 3)
|
||||
|
||||
|
||||
def test_features_remove_key(policy_feature_factory):
|
||||
@@ -1940,18 +1989,23 @@ def test_features_remove_key(policy_feature_factory):
|
||||
FeatureContractRemoveStep("a"),
|
||||
]
|
||||
)
|
||||
out = p.transform_features({})
|
||||
assert "a" not in out
|
||||
out = p.transform_features({PipelineFeatureType.OBSERVATION: {}})
|
||||
assert "a" not in out[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
|
||||
def test_features_remove_from_initial(policy_feature_factory):
|
||||
initial = {
|
||||
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"keep": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"drop": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
},
|
||||
}
|
||||
p = DataProcessorPipeline([FeatureContractRemoveStep("drop")])
|
||||
out = p.transform_features(initial_features=initial)
|
||||
assert "drop" not in out and out["keep"] == initial["keep"]
|
||||
assert (
|
||||
"drop" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION]["keep"] == initial[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1961,13 +2015,15 @@ class AddActionEEAndJointFeatures(ProcessorStep):
|
||||
def __call__(self, tr):
|
||||
return tr
|
||||
|
||||
def transform_features(self, features: dict) -> dict:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# EE features
|
||||
features["action.ee.x"] = float
|
||||
features["action.ee.y"] = float
|
||||
features[PipelineFeatureType.ACTION]["action.ee.x"] = float
|
||||
features[PipelineFeatureType.ACTION]["action.ee.y"] = float
|
||||
# JOINT features
|
||||
features["action.j1.pos"] = float
|
||||
features["action.j2.pos"] = float
|
||||
features[PipelineFeatureType.ACTION]["action.j1.pos"] = float
|
||||
features[PipelineFeatureType.ACTION]["action.j2.pos"] = float
|
||||
return features
|
||||
|
||||
|
||||
@@ -1981,18 +2037,20 @@ class AddObservationStateFeatures(ProcessorStep):
|
||||
def __call__(self, tr):
|
||||
return tr
|
||||
|
||||
def transform_features(self, features: dict) -> dict:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# State features (mix EE and a joint state)
|
||||
features["observation.state.ee.x"] = float
|
||||
features["observation.state.j1.pos"] = float
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float
|
||||
if self.add_front_image:
|
||||
features["observation.images.front"] = self.front_image_shape
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape
|
||||
return features
|
||||
|
||||
|
||||
def test_aggregate_joint_action_only():
|
||||
rp = DataProcessorPipeline([AddActionEEAndJointFeatures()])
|
||||
initial = {"front": (480, 640, 3)}
|
||||
initial = {PipelineFeatureType.OBSERVATION: {"front": (480, 640, 3)}, PipelineFeatureType.ACTION: {}}
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
@@ -2014,7 +2072,7 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features=initial,
|
||||
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "observation.state"],
|
||||
)
|
||||
@@ -2042,7 +2100,7 @@ def test_aggregate_both_action_types():
|
||||
rp = DataProcessorPipeline([AddActionEEAndJointFeatures()])
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={},
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.j1", "action.j2.pos"],
|
||||
)
|
||||
@@ -2059,7 +2117,7 @@ def test_aggregate_images_when_use_videos_false():
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features=initial,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=False, # expect "image" dtype
|
||||
patterns=None,
|
||||
)
|
||||
@@ -2076,7 +2134,7 @@ def test_aggregate_images_when_use_videos_true():
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features=initial,
|
||||
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
|
||||
use_videos=True,
|
||||
patterns=None,
|
||||
)
|
||||
@@ -2100,7 +2158,7 @@ def test_initial_camera_not_overridden_by_step_image():
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features=initial,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=True,
|
||||
patterns=["observation.images.front"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user