refactor(pipeline): minor improvements (#1684)

* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update docstrings VanillaObservationProcessor

* chore(pipeline): move None check to base pipeline classes
This commit is contained in:
Steven Palma
2025-08-06 14:00:13 +02:00
committed by GitHub
parent 7beb040e8e
commit fd4ae3466b
8 changed files with 165 additions and 421 deletions
+1 -7
View File
@@ -16,11 +16,7 @@
from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import (
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
ActionProcessor,
DoneProcessor,
@@ -43,7 +39,6 @@ __all__ = [
"DoneProcessor",
"EnvTransition",
"IdentityProcessor",
"ImageProcessor",
"InfoProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
@@ -53,7 +48,6 @@ __all__ = [
"RenameProcessor",
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
+4 -2
View File
@@ -20,6 +20,7 @@ import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@dataclass
@@ -30,10 +31,11 @@ class DeviceProcessor:
specified device (CPU or GPU) before they are returned.
"""
device: str = "cpu"
device: torch.device = "cpu"
def __post_init__(self):
self.non_blocking = "cuda" in self.device
self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
+1 -5
View File
@@ -220,7 +220,6 @@ class UnnormalizerProcessor:
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -230,10 +229,8 @@ class UnnormalizerProcessor:
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
eps: float = 1e-8,
) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
def __post_init__(self):
# Handle deserialization from JSON config
@@ -308,7 +305,6 @@ class UnnormalizerProcessor:
def get_config(self) -> dict[str, Any]:
return {
"eps": self.eps,
"features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
},
+83 -193
View File
@@ -13,8 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from dataclasses import dataclass
import einops
import numpy as np
@@ -23,52 +22,27 @@ from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
@dataclass
class ImageProcessor:
"""Process image observations from environment format to policy format.
Converts images from:
- Channel-last (H, W, C) to channel-first (C, H, W)
- uint8 [0, 255] to float32 [0, 1]
- Adds batch dimension if needed
- Handles both single images and dictionaries of images
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
"""
Processes environment observations into the LeRobot format by handling both images and states.
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
Image processing:
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
- Adds a batch dimension if missing
- Supports single images and image dictionaries
if observation is None:
return transition
processed_obs = {}
# Copy all observations first
for key, value in observation.items():
processed_obs[key] = value
# Handle pixels key if present
pixels = observation.get("pixels")
if pixels is not None:
# Remove pixels from processed_obs since we'll replace it with processed images
processed_obs.pop("pixels", None)
# Determine image mapping
if isinstance(pixels, dict):
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
else:
imgs = {OBS_IMAGE: pixels}
# Process each image
for imgkey, img in imgs.items():
processed_img = self._process_single_image(img)
processed_obs[imgkey] = processed_img
# Return new transition with processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
State processing:
- Maps 'environment_state' to observation.environment_state
- Maps 'agent_pos' to observation.state
- Converts numpy arrays to tensors
- Adds a batch dimension if missing
"""
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
@@ -95,173 +69,89 @@ class ImageProcessor:
return img_tensor
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
pixels -> OBS_IMAGE,
observation.pixels -> OBS_IMAGE,
pixels.<cam> -> OBS_IMAGES.<cam>,
observation.pixels.<cam> -> OBS_IMAGES.<cam>
def _process_observation(self, observation):
"""
if "pixels" in features:
features[OBS_IMAGE] = features.pop("pixels")
if "observation.pixels" in features:
features[OBS_IMAGE] = features.pop("observation.pixels")
prefixes = ("pixels.", "observation.pixels.")
for key in list(features.keys()):
for p in prefixes:
if key.startswith(p):
suffix = key[len(p) :]
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
break
return features
@dataclass
class StateProcessor:
"""Process state observations from environment format to policy format.
Handles:
- environment_state -> observation.environment_state
- agent_pos -> observation.state
- Converts numpy arrays to tensors
- Adds batch dimension if needed
Processes both image and state observations.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
processed_obs = observation.copy()
if observation is None:
return transition
if "pixels" in processed_obs:
pixels = processed_obs.pop("pixels")
processed_obs = dict(observation) # Copy existing observations
if isinstance(pixels, dict):
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
else:
imgs = {OBS_IMAGE: pixels}
# Process environment_state
if "environment_state" in observation:
env_state = torch.from_numpy(observation["environment_state"]).float()
for imgkey, img in imgs.items():
processed_obs[imgkey] = self._process_single_image(img)
if "environment_state" in processed_obs:
env_state_np = processed_obs.pop("environment_state")
env_state = torch.from_numpy(env_state_np).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
processed_obs[OBS_ENV_STATE] = env_state
# Remove original key
del processed_obs["environment_state"]
# Process agent_pos
if "agent_pos" in observation:
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
if "agent_pos" in processed_obs:
agent_pos_np = processed_obs.pop("agent_pos")
agent_pos = torch.from_numpy(agent_pos_np).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
processed_obs[OBS_STATE] = agent_pos
# Remove original key
del processed_obs["agent_pos"]
# Return new transition with processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
return processed_obs
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def observation(self, observation):
return self._process_observation(observation)
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
environment_state -> OBS_ENV_STATE,
agent_pos -> OBS_STATE,
observation.environment_state -> OBS_ENV_STATE,
observation.agent_pos -> OBS_STATE
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
- environment_state -> OBS_ENV_STATE,
- agent_pos -> OBS_STATE,
- observation.environment_state -> OBS_ENV_STATE,
- observation.agent_pos -> OBS_STATE
"""
pairs = (
("environment_state", OBS_ENV_STATE),
("agent_pos", OBS_STATE),
)
for old, new in pairs:
if old in features:
features[new] = features.pop(old)
prefixed = f"observation.{old}"
if prefixed in features:
features[new] = features.pop(prefixed)
return features
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor:
"""Complete observation processor that combines image and state processing.
This processor replicates the functionality of the original preprocess_observation
function but in a modular, composable way that fits into the pipeline architecture.
"""
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
state_processor: StateProcessor = field(default_factory=StateProcessor)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# First process images
transition = self.image_processor(transition)
# Then process state
transition = self.state_processor(transition)
return transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {
"image_processor": self.image_processor.get_config(),
"state_processor": self.state_processor.get_config(),
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary."""
state = {}
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
return state
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary."""
image_state = {
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
}
state_state = {
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
}
self.image_processor.load_state_dict(image_state)
self.state_processor.load_state_dict(state_state)
def reset(self) -> None:
"""Reset processor state."""
self.image_processor.reset()
self.state_processor.reset()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features = self.image_processor.feature_contract(features)
features = self.state_processor.feature_contract(features)
exact_pairs = {
"pixels": OBS_IMAGE,
"environment_state": OBS_ENV_STATE,
"agent_pos": OBS_STATE,
}
prefix_pairs = {
"pixels.": f"{OBS_IMAGES}.",
}
for key in list(features.keys()):
matched_prefix = False
for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}"
if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
break
if matched_prefix:
continue
for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}":
if key in features:
features[new] = features.pop(key)
break
return features
+42 -89
View File
@@ -36,6 +36,7 @@ from lerobot.configs.types import PolicyFeature
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
# TODO(Steven): Use consts
OBSERVATION = "observation"
ACTION = "action"
REWARD = "reward"
@@ -45,19 +46,18 @@ class TransitionKey(str, Enum):
COMPLEMENTARY_DATA = "complementary_data"
class EnvTransition(TypedDict, total=False):
"""Environment transition data structure.
All fields are optional (total=False) to allow flexible usage.
"""
observation: dict[str, Any] | None
action: Any | torch.Tensor | None
reward: float | torch.Tensor | None
done: bool | torch.Tensor | None
truncated: bool | torch.Tensor | None
info: dict[str, Any] | None
complementary_data: dict[str, Any] | None
EnvTransition = TypedDict(
"EnvTransition",
{
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
TransitionKey.ACTION.value: Any | torch.Tensor | None,
TransitionKey.REWARD.value: float | torch.Tensor | None,
TransitionKey.DONE.value: bool | torch.Tensor | None,
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
TransitionKey.INFO.value: dict[str, Any] | None,
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
},
)
class ProcessorStepRegistry:
@@ -135,8 +135,8 @@ class ProcessorStepRegistry:
class ProcessorStep(Protocol):
"""Structural typing interface for a single processor step.
A step is any callable accepting a full `EnvTransition` tuple and
returning a (possibly modified) tuple of the same structure. Implementers
A step is any callable accepting a full `EnvTransition` dict and
returning a (possibly modified) dict of the same structure. Implementers
are encouragedbut not requiredto expose the optional helper methods
listed below. When present, these hooks let `RobotProcessor`
automatically serialise the step's configuration and learnable state using
@@ -254,24 +254,22 @@ class RobotProcessor(ModelHubMixin):
Composable, debuggable post-processing processor for robot transitions.
The class orchestrates an ordered collection of small, functional transformsstepsexecuted
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
and batch dictionaries, automatically converting between formats as needed.
Args:
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
name: Human-readable identifier that is persisted inside the JSON config.
Defaults to "RobotProcessor".
seed: Global seed forwarded to steps that choose to consume it. Defaults to None.
to_transition: Function to convert batch dict to EnvTransition tuple.
to_transition: Function to convert batch dict to EnvTransition dict.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition tuple to the desired output format.
Usually it is a batch dict or EnvTransition tuple.
to_output: Function to convert EnvTransition dict to the desired output format.
Usually it is a batch dict or EnvTransition dict.
Defaults to _default_transition_to_batch.
before_step_hooks: List of hooks called before each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
after_step_hooks: List of hooks called after each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
reset_hooks: List of hooks called during processor reset.
Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to
@@ -290,7 +288,6 @@ class RobotProcessor(ModelHubMixin):
steps: Sequence[ProcessorStep] = field(default_factory=list)
name: str = "RobotProcessor"
seed: int | None = None
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False
@@ -303,7 +300,6 @@ class RobotProcessor(ModelHubMixin):
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process data through all steps.
@@ -431,7 +427,6 @@ class RobotProcessor(ModelHubMixin):
config: dict[str, Any] = {
"name": self.name,
"seed": self.seed,
"steps": [],
}
@@ -728,7 +723,7 @@ class RobotProcessor(ModelHubMixin):
f"Make sure override keys match exact step class names or registry names."
)
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed"))
return cls(steps, loaded_config.get("name", "RobotProcessor"))
def __len__(self) -> int:
"""Return the number of steps in the processor."""
@@ -740,7 +735,7 @@ class RobotProcessor(ModelHubMixin):
* ``slice`` returns a new RobotProcessor with the sliced steps.
"""
if isinstance(idx, slice):
return RobotProcessor(self.steps[idx], self.name, self.seed)
return RobotProcessor(self.steps[idx], self.name)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
@@ -783,71 +778,11 @@ class RobotProcessor(ModelHubMixin):
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
) from None
def register_reset_hook(self, fn: Callable[[], None]):
"""Attach fn to be executed when reset is called."""
self.reset_hooks.append(fn)
def unregister_reset_hook(self, fn: Callable[[], None]):
"""Remove a previously registered reset hook.
Args:
fn: The exact function reference that was registered. Must be the same object.
Raises:
ValueError: If the hook is not found in the registered hooks.
"""
try:
self.reset_hooks.remove(fn)
except ValueError:
raise ValueError(
f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference."
) from None
def reset(self):
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
for step in self.steps:
if hasattr(step, "reset"):
step.reset() # type: ignore[attr-defined]
for fn in self.reset_hooks:
fn()
def profile_steps(
self, transition: EnvTransition, num_runs: int = 100, warmup_runs: int = 5
) -> dict[str, float]:
"""Profile the execution time of each step for performance optimization."""
import copy
import time
profile_results = {}
# Make a copy to avoid altering the original transition
transition_copy = copy.deepcopy(transition)
# Get intermediate transitions for each step using step_through
intermediate_transitions = list(self.step_through(transition_copy))
for idx, processor_step in enumerate(self.steps):
step_name = f"step_{idx}_{processor_step.__class__.__name__}"
# Use the appropriate input transition for this step
input_transition = intermediate_transitions[idx]
# Warm up - copy transition for each run to ensure consistent conditions
for _ in range(warmup_runs):
transition_copy = copy.deepcopy(input_transition)
_ = processor_step(transition_copy)
# Time the step - copy transition for each run to ensure consistent conditions
start_time = time.perf_counter()
for _ in range(num_runs):
transition_copy = copy.deepcopy(input_transition)
_ = processor_step(transition_copy)
end_time = time.perf_counter()
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
profile_results[step_name] = avg_time
return profile_results
def __repr__(self) -> str:
"""Return a readable string representation of the processor."""
@@ -864,9 +799,6 @@ class RobotProcessor(ModelHubMixin):
parts = [f"name='{self.name}'", steps_repr]
if self.seed is not None:
parts.append(f"seed={self.seed}")
return f"RobotProcessor({', '.join(parts)})"
def __post_init__(self):
@@ -931,6 +863,9 @@ class ObservationProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
processed_observation = self.observation(observation)
# Create a new transition dict with the processed observation
new_transition = transition.copy()
@@ -988,6 +923,9 @@ class ActionProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
processed_action = self.action(action)
# Create a new transition dict with the processed action
new_transition = transition.copy()
@@ -1044,6 +982,9 @@ class RewardProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
reward = transition.get(TransitionKey.REWARD)
if reward is None:
return transition
processed_reward = self.reward(reward)
# Create a new transition dict with the processed reward
new_transition = transition.copy()
@@ -1105,6 +1046,9 @@ class DoneProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
done = transition.get(TransitionKey.DONE)
if done is None:
return transition
processed_done = self.done(done)
# Create a new transition dict with the processed done flag
new_transition = transition.copy()
@@ -1162,6 +1106,9 @@ class TruncatedProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is None:
return transition
processed_truncated = self.truncated(truncated)
# Create a new transition dict with the processed truncated flag
new_transition = transition.copy()
@@ -1224,6 +1171,9 @@ class InfoProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition.get(TransitionKey.INFO)
if info is None:
return transition
processed_info = self.info(info)
# Create a new transition dict with the processed info
new_transition = transition.copy()
@@ -1267,6 +1217,9 @@ class ComplementaryDataProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition:
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return transition
processed_complementary_data = self.complementary_data(complementary_data)
# Create a new transition dict with the processed complementary data
new_transition = transition.copy()
+7 -19
View File
@@ -16,24 +16,21 @@
from dataclasses import dataclass, field
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor:
class RenameProcessor(ObservationProcessor):
"""Rename processor that renames keys in the observation."""
rename_map: dict[str, str] = field(default_factory=dict)
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
def observation(self, observation):
processed_obs = {}
for key, value in observation.items():
if key in self.rename_map:
@@ -41,20 +38,11 @@ class RenameProcessor:
else:
processed_obs[key] = value
# Create a new transition with the renamed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
return processed_obs
def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value.
+19 -34
View File
@@ -20,11 +20,7 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import (
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor import VanillaObservationProcessor
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed
@@ -46,7 +42,7 @@ def create_transition(
def test_process_single_image():
"""Test processing a single image."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -72,7 +68,7 @@ def test_process_single_image():
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -95,7 +91,7 @@ def test_process_image_dict():
def test_process_batched_image():
"""Test processing already batched images."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
@@ -112,7 +108,7 @@ def test_process_batched_image():
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
@@ -125,7 +121,7 @@ def test_invalid_image_format():
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
@@ -138,7 +134,7 @@ def test_invalid_image_dtype():
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
@@ -153,7 +149,7 @@ def test_no_pixels_in_observation():
def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
transition = create_transition()
result = processor(transition)
@@ -163,7 +159,7 @@ def test_none_observation():
def test_serialization_methods():
"""Test serialization methods."""
processor = ImageProcessor()
processor = VanillaObservationProcessor()
# Test get_config
config = processor.get_config()
@@ -182,7 +178,7 @@ def test_serialization_methods():
def test_process_environment_state():
"""Test processing environment_state."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
@@ -203,7 +199,7 @@ def test_process_environment_state():
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
@@ -224,7 +220,7 @@ def test_process_agent_pos():
def test_process_batched_states():
"""Test processing already batched states."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
@@ -242,7 +238,7 @@ def test_process_batched_states():
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
@@ -267,7 +263,7 @@ def test_process_both_states():
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = StateProcessor()
processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation)
@@ -359,17 +355,6 @@ def test_empty_observation():
assert processed_obs == {}
def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
@@ -426,7 +411,7 @@ def test_equivalent_with_image_dict():
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
processor = ImageProcessor()
processor = VanillaObservationProcessor()
features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
@@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
processor = ImageProcessor()
processor = VanillaObservationProcessor()
features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
@@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
processor = ImageProcessor()
processor = VanillaObservationProcessor()
features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
@@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
processor = StateProcessor()
processor = VanillaObservationProcessor()
features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
@@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
proc = StateProcessor()
proc = VanillaObservationProcessor()
features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
+8 -72
View File
@@ -363,32 +363,6 @@ def test_hooks():
assert after_calls == [0]
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
reset_called = []
def reset_hook():
reset_called.append(True)
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = create_transition()
pipeline(transition)
pipeline(transition)
assert step.counter == 2
# Reset should reset step and call hook
pipeline.reset()
assert step.counter == 0
assert len(reset_called) == 1
def test_unregister_hooks():
"""Test unregistering hooks from the pipeline."""
step = MockStep("test_step")
@@ -428,21 +402,6 @@ def test_unregister_hooks():
pipeline(transition)
assert len(after_calls) == 0
# Test reset_hook
reset_calls = []
def reset_hook():
reset_calls.append(True)
pipeline.register_reset_hook(reset_hook)
pipeline.reset()
assert len(reset_calls) == 1
pipeline.unregister_reset_hook(reset_hook)
reset_calls.clear()
pipeline.reset()
assert len(reset_calls) == 0
def test_unregister_nonexistent_hook():
"""Test error handling when unregistering hooks that don't exist."""
@@ -461,9 +420,6 @@ def test_unregister_nonexistent_hook():
with pytest.raises(ValueError, match="not found in after_step_hooks"):
pipeline.unregister_after_step_hook(some_hook)
with pytest.raises(ValueError, match="not found in reset_hooks"):
pipeline.unregister_reset_hook(reset_hook)
def test_multiple_hooks_and_selective_unregister():
"""Test registering multiple hooks and selectively unregistering them."""
@@ -552,22 +508,6 @@ def test_hook_execution_order_documentation():
assert execution_order == ["A", "C", "B"] # B is now last
def test_profile_steps():
"""Test step profiling functionality."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = create_transition()
profile_results = pipeline.profile_steps(transition, num_runs=10)
assert len(profile_results) == 2
assert "step_0_MockStep" in profile_results
assert "step_1_MockStep" in profile_results
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
def test_save_and_load_pretrained():
"""Test saving and loading pipeline.
@@ -581,7 +521,7 @@ def test_save_and_load_pretrained():
step1.counter = 5
step2.counter = 10
pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42)
pipeline = RobotProcessor([step1, step2], name="TestPipeline")
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
@@ -596,7 +536,6 @@ def test_save_and_load_pretrained():
config = json.load(f)
assert config["name"] == "TestPipeline"
assert config["seed"] == 42
assert len(config["steps"]) == 2
# Verify counters are saved in config, not in separate state files
@@ -607,7 +546,6 @@ def test_save_and_load_pretrained():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestPipeline"
assert loaded_pipeline.seed == 42
assert len(loaded_pipeline) == 2
# Check that counter was restored from config
@@ -1255,10 +1193,10 @@ def test_repr_with_custom_name():
def test_repr_with_seed():
"""Test __repr__ with seed parameter."""
step = MockStep("test_step")
pipeline = RobotProcessor([step], seed=42)
pipeline = RobotProcessor([step])
repr_str = repr(pipeline)
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep], seed=42)"
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])"
assert repr_str == expected
@@ -1266,19 +1204,17 @@ def test_repr_with_custom_name_and_seed():
"""Test __repr__ with both custom name and seed."""
step1 = MockStep("step1")
step2 = MockStepWithoutOptionalMethods()
pipeline = RobotProcessor([step1, step2], name="MyProcessor", seed=123)
pipeline = RobotProcessor([step1, step2], name="MyProcessor")
repr_str = repr(pipeline)
expected = (
"RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods], seed=123)"
)
expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])"
assert repr_str == expected
def test_repr_without_seed():
"""Test __repr__ when seed is explicitly None (should not show seed)."""
step = MockStep("test_step")
pipeline = RobotProcessor([step], name="TestProcessor", seed=None)
pipeline = RobotProcessor([step], name="TestProcessor")
repr_str = repr(pipeline)
expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])"
@@ -1306,10 +1242,10 @@ def test_repr_edge_case_long_names():
step3 = MockStepWithTensorState()
step4 = MockNonModuleStepWithState()
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames", seed=999)
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames")
repr_str = repr(pipeline)
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])"
assert repr_str == expected