refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)

* refactor(processor): signature of transform_features

* refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly

* refactor(processor): rename now is only for visual

* refactor(processor): update normalize processor

* refactor(processor): update vanilla processor features

* refactor(processor): feature contract now uses its own enum

* chore(processor): rename renameprocessor

* chore(processor): minor changes

* refactor(processor): add create & change aggregate

* refactor(processor): update aggregate

* refactor(processor): simplify to functions, fix features contracts and rename function

* test(processor): remove to converter tests as now they are very simple

* chore(docs): recover docs joint observations processor

* fix(processor): update RKP

* fix(tests): recv diff test_pipeline

* chore(tests): add docs to test

* chore(processor): leave obs language constant untouched

* fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
Steven Palma
2025-09-09 18:27:30 +02:00
committed by GitHub
parent acf0ba7fb3
commit e881fb6678
47 changed files with 781 additions and 616 deletions
+88 -32
View File
@@ -18,7 +18,7 @@ import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType
from lerobot.configs.types import FeatureType, PipelineFeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
from tests.conftest import assert_contract_is_typed
@@ -412,74 +412,130 @@ def test_equivalent_with_image_dict():
def test_image_processor_features_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessorStep()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
},
}
out = processor.transform_features(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
assert "pixels" not in out
assert out["keep"] == features["keep"]
assert (
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
== features[PipelineFeatureType.OBSERVATION]["pixels"]
)
assert "pixels" not in out[PipelineFeatureType.OBSERVATION]
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_image_processor_features_observation_pixels_to_image(policy_feature_factory):
processor = VanillaObservationProcessorStep()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
},
}
out = processor.transform_features(features.copy())
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
assert "observation.pixels" not in out
assert out["keep"] == features["keep"]
assert (
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
== features[PipelineFeatureType.OBSERVATION]["observation.pixels"]
)
assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION]
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory):
processor = VanillaObservationProcessorStep()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
PipelineFeatureType.OBSERVATION: {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
},
}
out = processor.transform_features(features.copy())
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
assert out["keep"] == features["keep"]
assert (
f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"]
== features[PipelineFeatureType.OBSERVATION]["pixels.front"]
)
assert (
f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"]
== features[PipelineFeatureType.OBSERVATION]["pixels.wrist"]
)
assert (
f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"]
== features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"]
)
assert (
"pixels.front" not in out[PipelineFeatureType.OBSERVATION]
and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION]
and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION]
)
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_state_processor_features_environment_and_agent_pos(policy_feature_factory):
processor = VanillaObservationProcessorStep()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
PipelineFeatureType.OBSERVATION: {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
},
}
out = processor.transform_features(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert out["keep"] == features["keep"]
assert (
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
== features[PipelineFeatureType.OBSERVATION]["environment_state"]
)
assert (
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
== features[PipelineFeatureType.OBSERVATION]["agent_pos"]
)
assert (
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
)
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
assert_contract_is_typed(out)
def test_state_processor_features_prefixed_inputs(policy_feature_factory):
proc = VanillaObservationProcessorStep()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
PipelineFeatureType.OBSERVATION: {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
},
}
out = proc.transform_features(features.copy())
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
assert "environment_state" not in out and "agent_pos" not in out
assert (
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
)
assert (
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
== features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"]
)
assert (
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
)
assert_contract_is_typed(out)