mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)
* refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
@@ -18,7 +18,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
@@ -412,74 +412,130 @@ def test_equivalent_with_image_dict():
|
||||
def test_image_processor_features_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
|
||||
assert "pixels" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["pixels"]
|
||||
)
|
||||
assert "pixels" not in out[PipelineFeatureType.OBSERVATION]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_features_observation_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
|
||||
assert "observation.pixels" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.pixels"]
|
||||
)
|
||||
assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory):
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
|
||||
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
|
||||
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
|
||||
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"]
|
||||
== features[PipelineFeatureType.OBSERVATION]["pixels.front"]
|
||||
)
|
||||
assert (
|
||||
f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"]
|
||||
== features[PipelineFeatureType.OBSERVATION]["pixels.wrist"]
|
||||
)
|
||||
assert (
|
||||
f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"]
|
||||
)
|
||||
assert (
|
||||
"pixels.front" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_features_environment_and_agent_pos(policy_feature_factory):
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
|
||||
assert "environment_state" not in out and "agent_pos" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["environment_state"]
|
||||
)
|
||||
assert (
|
||||
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["agent_pos"]
|
||||
)
|
||||
assert (
|
||||
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
proc = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
},
|
||||
}
|
||||
out = proc.transform_features(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
|
||||
assert "environment_state" not in out and "agent_pos" not in out
|
||||
assert (
|
||||
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
|
||||
)
|
||||
assert (
|
||||
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"]
|
||||
)
|
||||
assert (
|
||||
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
|
||||
)
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
Reference in New Issue
Block a user