refactor(pipeline): minor improvements (#1684)

* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update VanillaObservationProcessor docstrings

* chore(pipeline): move None check to base pipeline classes
This commit is contained in:
Steven Palma
2025-08-06 14:00:13 +02:00
committed by GitHub
parent 7beb040e8e
commit fd4ae3466b
8 changed files with 165 additions and 421 deletions
+19 -34
View File
@@ -20,11 +20,7 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import (
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor import VanillaObservationProcessor
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
@@ -46,7 +42,7 @@ def create_transition(
def test_process_single_image():
"""Test processing a single image."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -72,7 +68,7 @@ def test_process_single_image():
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -95,7 +91,7 @@ def test_process_image_dict():
def test_process_batched_image():
"""Test processing already batched images."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
@@ -112,7 +108,7 @@ def test_process_batched_image():
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
@@ -125,7 +121,7 @@ def test_invalid_image_format():
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
@@ -138,7 +134,7 @@ def test_invalid_image_dtype():
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
@@ -153,7 +149,7 @@ def test_no_pixels_in_observation():
def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
transition = create_transition()
result = processor(transition)
@@ -163,7 +159,7 @@ def test_none_observation():
def test_serialization_methods():
"""Test serialization methods."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Test get_config
config = processor.get_config()
@@ -182,7 +178,7 @@ def test_serialization_methods():
def test_process_environment_state():
"""Test processing environment_state."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
@@ -203,7 +199,7 @@ def test_process_environment_state():
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
@@ -224,7 +220,7 @@ def test_process_agent_pos():
def test_process_batched_states():
"""Test processing already batched states."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
@@ -242,7 +238,7 @@ def test_process_batched_states():
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
@@ -267,7 +263,7 @@ def test_process_both_states():
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
@@ -359,17 +355,6 @@ def test_empty_observation():
assert processed_obs == {}
def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
@@ -426,7 +411,7 @@ def test_equivalent_with_image_dict():
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
processor = ImageProcessor()
processor = VanillaObservationProcessor()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
@@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
processor = ImageProcessor()
processor = VanillaObservationProcessor()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
@@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
processor = ImageProcessor()
processor = VanillaObservationProcessor()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
@@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
processor = StateProcessor()
processor = VanillaObservationProcessor()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
@@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
proc = StateProcessor()
proc = VanillaObservationProcessor()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),