mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
refactor(pipeline): feature contract now categorizes features as either OBS or Action (#1867)
* refactor(processor): signature of transform_features
* refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly
* refactor(processor): rename now is only for visual
* refactor(processor): update normalize processor
* refactor(processor): update vanilla processor features
* refactor(processor): feature contract now uses its own enum
* chore(processor): rename renameprocessor
* chore(processor): minor changes
* refactor(processor): add create & change aggregate
* refactor(processor): update aggregate
* refactor(processor): simplify to functions, fix features contracts and rename function
* test(processor): remove to converter tests as now they are very simple
* chore(docs): recover docs joint observations processor
* fix(processor): update RKP
* fix(tests): recv diff test_pipeline
* chore(tests): add docs to test
* chore(processor): leave obs language constant untouched
* fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
@@ -19,11 +19,11 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
@@ -51,7 +51,7 @@ def test_basic_renaming():
|
||||
"old_key1": "new_key1",
|
||||
"old_key2": "new_key2",
|
||||
}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"old_key1": torch.tensor([1.0, 2.0]),
|
||||
@@ -79,7 +79,7 @@ def test_basic_renaming():
|
||||
|
||||
def test_empty_rename_map():
|
||||
"""Test processor with empty rename map (should pass through unchanged)."""
|
||||
processor = RenameProcessorStep(rename_map={})
|
||||
processor = RenameObservationsProcessorStep(rename_map={})
|
||||
|
||||
observation = {
|
||||
"key1": torch.tensor([1.0]),
|
||||
@@ -98,7 +98,7 @@ def test_empty_rename_map():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = RenameProcessorStep(rename_map={"old": "new"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
@@ -113,7 +113,7 @@ def test_overlapping_rename():
|
||||
"a": "b",
|
||||
"b": "c", # This creates a potential conflict
|
||||
}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"a": 1,
|
||||
@@ -138,7 +138,7 @@ def test_partial_rename():
|
||||
"observation.state": "observation.proprio_state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
@@ -168,7 +168,7 @@ def test_get_config():
|
||||
"old1": "new1",
|
||||
"old2": "new2",
|
||||
}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"rename_map": rename_map}
|
||||
@@ -176,7 +176,7 @@ def test_get_config():
|
||||
|
||||
def test_state_dict():
|
||||
"""Test state dict (should be empty for RenameProcessorStep)."""
|
||||
processor = RenameProcessorStep(rename_map={"old": "new"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
|
||||
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
@@ -191,7 +191,7 @@ def test_integration_with_robot_processor():
|
||||
"agent_pos": "observation.state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
rename_processor = RenameProcessorStep(rename_map=rename_map)
|
||||
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
pipeline = DataProcessorPipeline([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
@@ -225,7 +225,7 @@ def test_save_and_load_pretrained():
|
||||
"old_state": "observation.state",
|
||||
"old_image": "observation.image",
|
||||
}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -252,7 +252,7 @@ def test_save_and_load_pretrained():
|
||||
|
||||
# Check that loaded processor works correctly
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessorStep)
|
||||
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
|
||||
assert loaded_processor.rename_map == rename_map
|
||||
|
||||
# Test functionality after loading
|
||||
@@ -271,21 +271,21 @@ def test_save_and_load_pretrained():
|
||||
def test_registry_functionality():
|
||||
"""Test that RenameProcessorStep is properly registered."""
|
||||
# Check that it's registered
|
||||
assert "rename_processor" in ProcessorStepRegistry.list()
|
||||
assert "rename_observations_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
# Get from registry
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_processor")
|
||||
assert retrieved_class is RenameProcessorStep
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_observations_processor")
|
||||
assert retrieved_class is RenameObservationsProcessorStep
|
||||
|
||||
# Create instance from registry
|
||||
instance = retrieved_class(rename_map={"old": "new"})
|
||||
assert isinstance(instance, RenameProcessorStep)
|
||||
assert isinstance(instance, RenameObservationsProcessorStep)
|
||||
assert instance.rename_map == {"old": "new"}
|
||||
|
||||
|
||||
def test_registry_based_save_load():
|
||||
"""Test save/load using registry name instead of module path."""
|
||||
processor = RenameProcessorStep(rename_map={"key1": "renamed_key1"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -299,20 +299,20 @@ def test_registry_based_save_load():
|
||||
config = json.load(f)
|
||||
|
||||
assert "registry_name" in config["steps"][0]
|
||||
assert config["steps"][0]["registry_name"] == "rename_processor"
|
||||
assert config["steps"][0]["registry_name"] == "rename_observations_processor"
|
||||
assert "class" not in config["steps"][0] # Should use registry, not module path
|
||||
|
||||
# Load should work
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(tmp_dir)
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessorStep)
|
||||
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
|
||||
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
|
||||
|
||||
|
||||
def test_chained_rename_processors():
|
||||
"""Test multiple RenameProcessorSteps in a pipeline."""
|
||||
# First processor: rename raw keys to intermediate format
|
||||
processor1 = RenameProcessorStep(
|
||||
processor1 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"pos": "agent_position",
|
||||
"img": "camera_image",
|
||||
@@ -320,7 +320,7 @@ def test_chained_rename_processors():
|
||||
)
|
||||
|
||||
# Second processor: rename to final format
|
||||
processor2 = RenameProcessorStep(
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"agent_position": "observation.state",
|
||||
"camera_image": "observation.image",
|
||||
@@ -365,7 +365,7 @@ def test_nested_observation_rename():
|
||||
"observation.images.right": "observation.camera.right_view",
|
||||
"observation.proprio": "observation.proprioception",
|
||||
}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.images.left": torch.randn(3, 64, 64),
|
||||
@@ -395,7 +395,7 @@ def test_nested_observation_rename():
|
||||
def test_value_types_preserved():
|
||||
"""Test that various value types are preserved during renaming."""
|
||||
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
|
||||
processor = RenameProcessorStep(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
tensor_value = torch.randn(3, 3)
|
||||
array_value = np.random.rand(2, 2)
|
||||
@@ -423,59 +423,75 @@ def test_value_types_preserved():
|
||||
|
||||
|
||||
def test_features_basic_renaming(policy_feature_factory):
|
||||
processor = RenameProcessorStep(rename_map={"a": "x", "b": "y"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"a": "x", "b": "y"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"a": policy_feature_factory(FeatureType.VISUAL, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.VISUAL, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
},
|
||||
}
|
||||
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
# Values preserved and typed
|
||||
assert out["x"] == features["a"]
|
||||
assert out["y"] == features["b"]
|
||||
assert out["c"] == features["c"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["x"] == features[PipelineFeatureType.OBSERVATION]["a"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["y"] == features[PipelineFeatureType.OBSERVATION]["b"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["c"]
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
# Input not mutated
|
||||
assert set(features) == {"a", "b", "c"}
|
||||
assert set(features[PipelineFeatureType.OBSERVATION]) == {"a", "b", "c"}
|
||||
|
||||
|
||||
def test_features_overlapping_keys(policy_feature_factory):
|
||||
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
|
||||
processor = RenameProcessorStep(rename_map={"a": "b", "b": "c"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"a": "b", "b": "c"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"a": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.VISUAL, (2,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features)
|
||||
|
||||
assert set(out) == {"b", "c"}
|
||||
assert out["b"] == features["a"] # 'a' renamed to'b'
|
||||
assert out["c"] == features["b"] # 'b' renamed to 'c'
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"b", "c"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["b"] == features[PipelineFeatureType.OBSERVATION]["a"]
|
||||
) # 'a' renamed to'b'
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["b"]
|
||||
) # 'b' renamed to 'c'
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_features_chained_processors(policy_feature_factory):
|
||||
# Chain two rename processors at the contract level
|
||||
processor1 = RenameProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameProcessorStep(
|
||||
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||
)
|
||||
pipeline = DataProcessorPipeline([processor1, processor2])
|
||||
|
||||
spec = {
|
||||
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pos": policy_feature_factory(FeatureType.VISUAL, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
},
|
||||
}
|
||||
out = pipeline.transform_features(initial_features=spec)
|
||||
|
||||
assert set(out) == {"observation.state", "observation.image", "extra"}
|
||||
assert out["observation.state"] == spec["pos"]
|
||||
assert out["observation.image"] == spec["img"]
|
||||
assert out["extra"] == spec["extra"]
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.state"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
)
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.image"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user