mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
refactor(pipeline): minor improvements (#1684)
* chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes
This commit is contained in:
@@ -20,11 +20,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import (
|
||||
ImageProcessor,
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
from lerobot.processor import VanillaObservationProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
@@ -46,7 +42,7 @@ def create_transition(
|
||||
|
||||
def test_process_single_image():
|
||||
"""Test processing a single image."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create a mock image (H, W, C) format, uint8
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
@@ -72,7 +68,7 @@ def test_process_single_image():
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create mock images
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -95,7 +91,7 @@ def test_process_image_dict():
|
||||
|
||||
def test_process_batched_image():
|
||||
"""Test processing already batched images."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create a batched image (B, H, W, C)
|
||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||
@@ -112,7 +108,7 @@ def test_process_batched_image():
|
||||
|
||||
def test_invalid_image_format():
|
||||
"""Test error handling for invalid image formats."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test wrong channel order (channels first)
|
||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||
@@ -125,7 +121,7 @@ def test_invalid_image_format():
|
||||
|
||||
def test_invalid_image_dtype():
|
||||
"""Test error handling for invalid image dtype."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test wrong dtype
|
||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||
@@ -138,7 +134,7 @@ def test_invalid_image_dtype():
|
||||
|
||||
def test_no_pixels_in_observation():
|
||||
"""Test processor when no pixels are in observation."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -153,7 +149,7 @@ def test_no_pixels_in_observation():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
@@ -163,7 +159,7 @@ def test_none_observation():
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
@@ -182,7 +178,7 @@ def test_serialization_methods():
|
||||
|
||||
def test_process_environment_state():
|
||||
"""Test processing environment_state."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
@@ -203,7 +199,7 @@ def test_process_environment_state():
|
||||
|
||||
def test_process_agent_pos():
|
||||
"""Test processing agent_pos."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
@@ -224,7 +220,7 @@ def test_process_agent_pos():
|
||||
|
||||
def test_process_batched_states():
|
||||
"""Test processing already batched states."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
@@ -242,7 +238,7 @@ def test_process_batched_states():
|
||||
|
||||
def test_process_both_states():
|
||||
"""Test processing both environment_state and agent_pos."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
@@ -267,7 +263,7 @@ def test_process_both_states():
|
||||
|
||||
def test_no_states_in_observation():
|
||||
"""Test processor when no states are in observation."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -359,17 +355,6 @@ def test_empty_observation():
|
||||
assert processed_obs == {}
|
||||
|
||||
|
||||
def test_custom_sub_processors():
|
||||
"""Test ObservationProcessor with custom sub-processors."""
|
||||
image_proc = ImageProcessor()
|
||||
state_proc = StateProcessor()
|
||||
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||
|
||||
# Should use the provided processors
|
||||
assert processor.image_processor is image_proc
|
||||
assert processor.state_processor is state_proc
|
||||
|
||||
|
||||
def test_equivalent_to_original_function():
|
||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
||||
# Import the original function for comparison
|
||||
@@ -426,7 +411,7 @@ def test_equivalent_with_image_dict():
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
@@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
@@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
@@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
@@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
||||
proc = StateProcessor()
|
||||
proc = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
|
||||
@@ -363,32 +363,6 @@ def test_hooks():
|
||||
assert after_calls == [0]
|
||||
|
||||
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
reset_called = []
|
||||
|
||||
def reset_hook():
|
||||
reset_called.append(True)
|
||||
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
|
||||
# Make some calls to increment counter
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
pipeline(transition)
|
||||
|
||||
assert step.counter == 2
|
||||
|
||||
# Reset should reset step and call hook
|
||||
pipeline.reset()
|
||||
|
||||
assert step.counter == 0
|
||||
assert len(reset_called) == 1
|
||||
|
||||
|
||||
def test_unregister_hooks():
|
||||
"""Test unregistering hooks from the pipeline."""
|
||||
step = MockStep("test_step")
|
||||
@@ -428,21 +402,6 @@ def test_unregister_hooks():
|
||||
pipeline(transition)
|
||||
assert len(after_calls) == 0
|
||||
|
||||
# Test reset_hook
|
||||
reset_calls = []
|
||||
|
||||
def reset_hook():
|
||||
reset_calls.append(True)
|
||||
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
pipeline.reset()
|
||||
assert len(reset_calls) == 1
|
||||
|
||||
pipeline.unregister_reset_hook(reset_hook)
|
||||
reset_calls.clear()
|
||||
pipeline.reset()
|
||||
assert len(reset_calls) == 0
|
||||
|
||||
|
||||
def test_unregister_nonexistent_hook():
|
||||
"""Test error handling when unregistering hooks that don't exist."""
|
||||
@@ -461,9 +420,6 @@ def test_unregister_nonexistent_hook():
|
||||
with pytest.raises(ValueError, match="not found in after_step_hooks"):
|
||||
pipeline.unregister_after_step_hook(some_hook)
|
||||
|
||||
with pytest.raises(ValueError, match="not found in reset_hooks"):
|
||||
pipeline.unregister_reset_hook(reset_hook)
|
||||
|
||||
|
||||
def test_multiple_hooks_and_selective_unregister():
|
||||
"""Test registering multiple hooks and selectively unregistering them."""
|
||||
@@ -552,22 +508,6 @@ def test_hook_execution_order_documentation():
|
||||
assert execution_order == ["A", "C", "B"] # B is now last
|
||||
|
||||
|
||||
def test_profile_steps():
|
||||
"""Test step profiling functionality."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = create_transition()
|
||||
|
||||
profile_results = pipeline.profile_steps(transition, num_runs=10)
|
||||
|
||||
assert len(profile_results) == 2
|
||||
assert "step_0_MockStep" in profile_results
|
||||
assert "step_1_MockStep" in profile_results
|
||||
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading pipeline.
|
||||
|
||||
@@ -581,7 +521,7 @@ def test_save_and_load_pretrained():
|
||||
step1.counter = 5
|
||||
step2.counter = 10
|
||||
|
||||
pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42)
|
||||
pipeline = RobotProcessor([step1, step2], name="TestPipeline")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
@@ -596,7 +536,6 @@ def test_save_and_load_pretrained():
|
||||
config = json.load(f)
|
||||
|
||||
assert config["name"] == "TestPipeline"
|
||||
assert config["seed"] == 42
|
||||
assert len(config["steps"]) == 2
|
||||
|
||||
# Verify counters are saved in config, not in separate state files
|
||||
@@ -607,7 +546,6 @@ def test_save_and_load_pretrained():
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert loaded_pipeline.name == "TestPipeline"
|
||||
assert loaded_pipeline.seed == 42
|
||||
assert len(loaded_pipeline) == 2
|
||||
|
||||
# Check that counter was restored from config
|
||||
@@ -1255,10 +1193,10 @@ def test_repr_with_custom_name():
|
||||
def test_repr_with_seed():
|
||||
"""Test __repr__ with seed parameter."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step], seed=42)
|
||||
pipeline = RobotProcessor([step])
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep], seed=42)"
|
||||
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
@@ -1266,19 +1204,17 @@ def test_repr_with_custom_name_and_seed():
|
||||
"""Test __repr__ with both custom name and seed."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStepWithoutOptionalMethods()
|
||||
pipeline = RobotProcessor([step1, step2], name="MyProcessor", seed=123)
|
||||
pipeline = RobotProcessor([step1, step2], name="MyProcessor")
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = (
|
||||
"RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods], seed=123)"
|
||||
)
|
||||
expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
def test_repr_without_seed():
|
||||
"""Test __repr__ when seed is explicitly None (should not show seed)."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step], name="TestProcessor", seed=None)
|
||||
pipeline = RobotProcessor([step], name="TestProcessor")
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])"
|
||||
@@ -1306,10 +1242,10 @@ def test_repr_edge_case_long_names():
|
||||
step3 = MockStepWithTensorState()
|
||||
step4 = MockNonModuleStepWithState()
|
||||
|
||||
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames", seed=999)
|
||||
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames")
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
|
||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user