refactor(pipeline): minor improvements (#1684)

* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update docstrings VanillaObservationProcessor

* chore(pipeline): move None check to base pipeline classes
This commit is contained in:
Steven Palma
2025-08-06 14:00:13 +02:00
committed by GitHub
parent 7beb040e8e
commit fd4ae3466b
8 changed files with 165 additions and 421 deletions
+1 -7
View File
@@ -16,11 +16,7 @@
from .device_processor import DeviceProcessor from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import ( from .observation_processor import VanillaObservationProcessor
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from .pipeline import ( from .pipeline import (
ActionProcessor, ActionProcessor,
DoneProcessor, DoneProcessor,
@@ -43,7 +39,6 @@ __all__ = [
"DoneProcessor", "DoneProcessor",
"EnvTransition", "EnvTransition",
"IdentityProcessor", "IdentityProcessor",
"ImageProcessor",
"InfoProcessor", "InfoProcessor",
"NormalizerProcessor", "NormalizerProcessor",
"UnnormalizerProcessor", "UnnormalizerProcessor",
@@ -53,7 +48,6 @@ __all__ = [
"RenameProcessor", "RenameProcessor",
"RewardProcessor", "RewardProcessor",
"RobotProcessor", "RobotProcessor",
"StateProcessor",
"TransitionKey", "TransitionKey",
"TruncatedProcessor", "TruncatedProcessor",
"VanillaObservationProcessor", "VanillaObservationProcessor",
+4 -2
View File
@@ -20,6 +20,7 @@ import torch
from lerobot.configs.types import PolicyFeature from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@dataclass @dataclass
@@ -30,10 +31,11 @@ class DeviceProcessor:
specified device (CPU or GPU) before they are returned. specified device (CPU or GPU) before they are returned.
""" """
device: str = "cpu" device: torch.device = "cpu"
def __post_init__(self): def __post_init__(self):
self.non_blocking = "cuda" in self.device self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition # Create a copy of the transition
+1 -5
View File
@@ -220,7 +220,6 @@ class UnnormalizerProcessor:
features: dict[str, PolicyFeature] features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode] norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None stats: dict[str, dict[str, Any]] | None = None
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -230,10 +229,8 @@ class UnnormalizerProcessor:
dataset: LeRobotDataset, dataset: LeRobotDataset,
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode], norm_map: dict[FeatureType, NormalizationMode],
*,
eps: float = 1e-8,
) -> UnnormalizerProcessor: ) -> UnnormalizerProcessor:
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps) return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
def __post_init__(self): def __post_init__(self):
# Handle deserialization from JSON config # Handle deserialization from JSON config
@@ -308,7 +305,6 @@ class UnnormalizerProcessor:
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
return { return {
"eps": self.eps,
"features": { "features": {
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
}, },
+83 -193
View File
@@ -13,8 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Any
import einops import einops
import numpy as np import numpy as np
@@ -23,52 +22,27 @@ from torch import Tensor
from lerobot.configs.types import PolicyFeature from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
@dataclass @dataclass
class ImageProcessor: @ProcessorStepRegistry.register(name="observation_processor")
"""Process image observations from environment format to policy format. class VanillaObservationProcessor(ObservationProcessor):
Converts images from:
- Channel-last (H, W, C) to channel-first (C, H, W)
- uint8 [0, 255] to float32 [0, 1]
- Adds batch dimension if needed
- Handles both single images and dictionaries of images
""" """
Processes environment observations into the LeRobot format by handling both images and states.
def __call__(self, transition: EnvTransition) -> EnvTransition: Image processing:
observation = transition.get(TransitionKey.OBSERVATION) - Converts channel-last (H, W, C) images to channel-first (C, H, W)
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
- Adds a batch dimension if missing
- Supports single images and image dictionaries
if observation is None: State processing:
return transition - Maps 'environment_state' to observation.environment_state
- Maps 'agent_pos' to observation.state
processed_obs = {} - Converts numpy arrays to tensors
- Adds a batch dimension if missing
# Copy all observations first """
for key, value in observation.items():
processed_obs[key] = value
# Handle pixels key if present
pixels = observation.get("pixels")
if pixels is not None:
# Remove pixels from processed_obs since we'll replace it with processed images
processed_obs.pop("pixels", None)
# Determine image mapping
if isinstance(pixels, dict):
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
else:
imgs = {OBS_IMAGE: pixels}
# Process each image
for imgkey, img in imgs.items():
processed_img = self._process_single_image(img)
processed_obs[imgkey] = processed_img
# Return new transition with processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def _process_single_image(self, img: np.ndarray) -> Tensor: def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array.""" """Process a single image array."""
@@ -95,173 +69,89 @@ class ImageProcessor:
return img_tensor return img_tensor
def get_config(self) -> dict[str, Any]: def _process_observation(self, observation):
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms:
pixels -> OBS_IMAGE,
observation.pixels -> OBS_IMAGE,
pixels.<cam> -> OBS_IMAGES.<cam>,
observation.pixels.<cam> -> OBS_IMAGES.<cam>
""" """
if "pixels" in features: Processes both image and state observations.
features[OBS_IMAGE] = features.pop("pixels")
if "observation.pixels" in features:
features[OBS_IMAGE] = features.pop("observation.pixels")
prefixes = ("pixels.", "observation.pixels.")
for key in list(features.keys()):
for p in prefixes:
if key.startswith(p):
suffix = key[len(p) :]
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
break
return features
@dataclass
class StateProcessor:
"""Process state observations from environment format to policy format.
Handles:
- environment_state -> observation.environment_state
- agent_pos -> observation.state
- Converts numpy arrays to tensors
- Adds batch dimension if needed
""" """
def __call__(self, transition: EnvTransition) -> EnvTransition: processed_obs = observation.copy()
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None: if "pixels" in processed_obs:
return transition pixels = processed_obs.pop("pixels")
processed_obs = dict(observation) # Copy existing observations if isinstance(pixels, dict):
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
else:
imgs = {OBS_IMAGE: pixels}
# Process environment_state for imgkey, img in imgs.items():
if "environment_state" in observation: processed_obs[imgkey] = self._process_single_image(img)
env_state = torch.from_numpy(observation["environment_state"]).float()
if "environment_state" in processed_obs:
env_state_np = processed_obs.pop("environment_state")
env_state = torch.from_numpy(env_state_np).float()
if env_state.dim() == 1: if env_state.dim() == 1:
env_state = env_state.unsqueeze(0) env_state = env_state.unsqueeze(0)
processed_obs[OBS_ENV_STATE] = env_state processed_obs[OBS_ENV_STATE] = env_state
# Remove original key
del processed_obs["environment_state"]
# Process agent_pos if "agent_pos" in processed_obs:
if "agent_pos" in observation: agent_pos_np = processed_obs.pop("agent_pos")
agent_pos = torch.from_numpy(observation["agent_pos"]).float() agent_pos = torch.from_numpy(agent_pos_np).float()
if agent_pos.dim() == 1: if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0) agent_pos = agent_pos.unsqueeze(0)
processed_obs[OBS_STATE] = agent_pos processed_obs[OBS_STATE] = agent_pos
# Remove original key
del processed_obs["agent_pos"]
# Return new transition with processed observation return processed_obs
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]: def observation(self, observation):
"""Return configuration for serialization.""" return self._process_observation(observation)
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms: """Transforms feature keys to a standardized contract.
environment_state -> OBS_ENV_STATE,
agent_pos -> OBS_STATE, This method handles several renaming patterns:
observation.environment_state -> OBS_ENV_STATE, - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
observation.agent_pos -> OBS_STATE - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
- environment_state -> OBS_ENV_STATE,
- agent_pos -> OBS_STATE,
- observation.environment_state -> OBS_ENV_STATE,
- observation.agent_pos -> OBS_STATE
""" """
pairs = ( exact_pairs = {
("environment_state", OBS_ENV_STATE), "pixels": OBS_IMAGE,
("agent_pos", OBS_STATE), "environment_state": OBS_ENV_STATE,
) "agent_pos": OBS_STATE,
for old, new in pairs: }
if old in features:
features[new] = features.pop(old) prefix_pairs = {
prefixed = f"observation.{old}" "pixels.": f"{OBS_IMAGES}.",
if prefixed in features: }
features[new] = features.pop(prefixed)
return features for key in list(features.keys()):
matched_prefix = False
for old_prefix, new_prefix in prefix_pairs.items():
@dataclass prefixed_old = f"observation.{old_prefix}"
@ProcessorStepRegistry.register(name="observation_processor") if key.startswith(prefixed_old):
class VanillaObservationProcessor: suffix = key[len(prefixed_old) :]
"""Complete observation processor that combines image and state processing. features[f"{new_prefix}{suffix}"] = features.pop(key)
matched_prefix = True
This processor replicates the functionality of the original preprocess_observation break
function but in a modular, composable way that fits into the pipeline architecture.
""" if key.startswith(old_prefix):
suffix = key[len(old_prefix) :]
image_processor: ImageProcessor = field(default_factory=ImageProcessor) features[f"{new_prefix}{suffix}"] = features.pop(key)
state_processor: StateProcessor = field(default_factory=StateProcessor) matched_prefix = True
break
def __call__(self, transition: EnvTransition) -> EnvTransition:
# First process images if matched_prefix:
transition = self.image_processor(transition) continue
# Then process state
transition = self.state_processor(transition) for old, new in exact_pairs.items():
return transition if key == old or key == f"observation.{old}":
if key in features:
def get_config(self) -> dict[str, Any]: features[new] = features.pop(key)
"""Return configuration for serialization.""" break
return {
"image_processor": self.image_processor.get_config(),
"state_processor": self.state_processor.get_config(),
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary."""
state = {}
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
return state
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary."""
image_state = {
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
}
state_state = {
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
}
self.image_processor.load_state_dict(image_state)
self.state_processor.load_state_dict(state_state)
def reset(self) -> None:
"""Reset processor state."""
self.image_processor.reset()
self.state_processor.reset()
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features = self.image_processor.feature_contract(features)
features = self.state_processor.feature_contract(features)
return features return features
+42 -89
View File
@@ -36,6 +36,7 @@ from lerobot.configs.types import PolicyFeature
class TransitionKey(str, Enum): class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components.""" """Keys for accessing EnvTransition dictionary components."""
# TODO(Steven): Use consts
OBSERVATION = "observation" OBSERVATION = "observation"
ACTION = "action" ACTION = "action"
REWARD = "reward" REWARD = "reward"
@@ -45,19 +46,18 @@ class TransitionKey(str, Enum):
COMPLEMENTARY_DATA = "complementary_data" COMPLEMENTARY_DATA = "complementary_data"
class EnvTransition(TypedDict, total=False): EnvTransition = TypedDict(
"""Environment transition data structure. "EnvTransition",
{
All fields are optional (total=False) to allow flexible usage. TransitionKey.OBSERVATION.value: dict[str, Any] | None,
""" TransitionKey.ACTION.value: Any | torch.Tensor | None,
TransitionKey.REWARD.value: float | torch.Tensor | None,
observation: dict[str, Any] | None TransitionKey.DONE.value: bool | torch.Tensor | None,
action: Any | torch.Tensor | None TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
reward: float | torch.Tensor | None TransitionKey.INFO.value: dict[str, Any] | None,
done: bool | torch.Tensor | None TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
truncated: bool | torch.Tensor | None },
info: dict[str, Any] | None )
complementary_data: dict[str, Any] | None
class ProcessorStepRegistry: class ProcessorStepRegistry:
@@ -135,8 +135,8 @@ class ProcessorStepRegistry:
class ProcessorStep(Protocol): class ProcessorStep(Protocol):
"""Structural typing interface for a single processor step. """Structural typing interface for a single processor step.
A step is any callable accepting a full `EnvTransition` tuple and A step is any callable accepting a full `EnvTransition` dict and
returning a (possibly modified) tuple of the same structure. Implementers returning a (possibly modified) dict of the same structure. Implementers
are encouraged—but not required—to expose the optional helper methods are encouraged—but not required—to expose the optional helper methods
listed below. When present, these hooks let `RobotProcessor` listed below. When present, these hooks let `RobotProcessor`
automatically serialise the step's configuration and learnable state using automatically serialise the step's configuration and learnable state using
@@ -254,24 +254,22 @@ class RobotProcessor(ModelHubMixin):
Composable, debuggable post-processing processor for robot transitions. Composable, debuggable post-processing processor for robot transitions.
The class orchestrates an ordered collection of small, functional transforms—steps—executed The class orchestrates an ordered collection of small, functional transforms—steps—executed
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
and batch dictionaries, automatically converting between formats as needed. and batch dictionaries, automatically converting between formats as needed.
Args: Args:
steps: Ordered list of processing steps executed on every call. Defaults to empty list. steps: Ordered list of processing steps executed on every call. Defaults to empty list.
name: Human-readable identifier that is persisted inside the JSON config. name: Human-readable identifier that is persisted inside the JSON config.
Defaults to "RobotProcessor". Defaults to "RobotProcessor".
seed: Global seed forwarded to steps that choose to consume it. Defaults to None. to_transition: Function to convert batch dict to EnvTransition dict.
to_transition: Function to convert batch dict to EnvTransition tuple.
Defaults to _default_batch_to_transition. Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition tuple to the desired output format. to_output: Function to convert EnvTransition dict to the desired output format.
Usually it is a batch dict or EnvTransition tuple. Usually it is a batch dict or EnvTransition dict.
Defaults to _default_transition_to_batch. Defaults to _default_transition_to_batch.
before_step_hooks: List of hooks called before each step. Each hook receives the step before_step_hooks: List of hooks called before each step. Each hook receives the step
index and transition, and can optionally return a modified transition. index and transition, and can optionally return a modified transition.
after_step_hooks: List of hooks called after each step. Each hook receives the step after_step_hooks: List of hooks called after each step. Each hook receives the step
index and transition, and can optionally return a modified transition. index and transition, and can optionally return a modified transition.
reset_hooks: List of hooks called during processor reset.
Hook Semantics: Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to - Hooks are executed sequentially in the order they were registered. There is no way to
@@ -290,7 +288,6 @@ class RobotProcessor(ModelHubMixin):
steps: Sequence[ProcessorStep] = field(default_factory=list) steps: Sequence[ProcessorStep] = field(default_factory=list)
name: str = "RobotProcessor" name: str = "RobotProcessor"
seed: int | None = None
to_transition: Callable[[dict[str, Any]], EnvTransition] = field( to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False default_factory=lambda: _default_batch_to_transition, repr=False
@@ -303,7 +300,6 @@ class RobotProcessor(ModelHubMixin):
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, data: EnvTransition | dict[str, Any]): def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process data through all steps. """Process data through all steps.
@@ -431,7 +427,6 @@ class RobotProcessor(ModelHubMixin):
config: dict[str, Any] = { config: dict[str, Any] = {
"name": self.name, "name": self.name,
"seed": self.seed,
"steps": [], "steps": [],
} }
@@ -728,7 +723,7 @@ class RobotProcessor(ModelHubMixin):
f"Make sure override keys match exact step class names or registry names." f"Make sure override keys match exact step class names or registry names."
) )
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed")) return cls(steps, loaded_config.get("name", "RobotProcessor"))
def __len__(self) -> int: def __len__(self) -> int:
"""Return the number of steps in the processor.""" """Return the number of steps in the processor."""
@@ -740,7 +735,7 @@ class RobotProcessor(ModelHubMixin):
* ``slice`` returns a new RobotProcessor with the sliced steps. * ``slice`` returns a new RobotProcessor with the sliced steps.
""" """
if isinstance(idx, slice): if isinstance(idx, slice):
return RobotProcessor(self.steps[idx], self.name, self.seed) return RobotProcessor(self.steps[idx], self.name)
return self.steps[idx] return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
@@ -783,71 +778,11 @@ class RobotProcessor(ModelHubMixin):
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference." f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
) from None ) from None
def register_reset_hook(self, fn: Callable[[], None]):
"""Attach fn to be executed when reset is called."""
self.reset_hooks.append(fn)
def unregister_reset_hook(self, fn: Callable[[], None]):
"""Remove a previously registered reset hook.
Args:
fn: The exact function reference that was registered. Must be the same object.
Raises:
ValueError: If the hook is not found in the registered hooks.
"""
try:
self.reset_hooks.remove(fn)
except ValueError:
raise ValueError(
f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference."
) from None
def reset(self): def reset(self):
"""Clear state in every step that implements ``reset()`` and fire registered hooks.""" """Clear state in every step that implements ``reset()`` and fire registered hooks."""
for step in self.steps: for step in self.steps:
if hasattr(step, "reset"): if hasattr(step, "reset"):
step.reset() # type: ignore[attr-defined] step.reset() # type: ignore[attr-defined]
for fn in self.reset_hooks:
fn()
def profile_steps(
self, transition: EnvTransition, num_runs: int = 100, warmup_runs: int = 5
) -> dict[str, float]:
"""Profile the execution time of each step for performance optimization."""
import copy
import time
profile_results = {}
# Make a copy to avoid altering the original transition
transition_copy = copy.deepcopy(transition)
# Get intermediate transitions for each step using step_through
intermediate_transitions = list(self.step_through(transition_copy))
for idx, processor_step in enumerate(self.steps):
step_name = f"step_{idx}_{processor_step.__class__.__name__}"
# Use the appropriate input transition for this step
input_transition = intermediate_transitions[idx]
# Warm up - copy transition for each run to ensure consistent conditions
for _ in range(warmup_runs):
transition_copy = copy.deepcopy(input_transition)
_ = processor_step(transition_copy)
# Time the step - copy transition for each run to ensure consistent conditions
start_time = time.perf_counter()
for _ in range(num_runs):
transition_copy = copy.deepcopy(input_transition)
_ = processor_step(transition_copy)
end_time = time.perf_counter()
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
profile_results[step_name] = avg_time
return profile_results
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return a readable string representation of the processor.""" """Return a readable string representation of the processor."""
@@ -864,9 +799,6 @@ class RobotProcessor(ModelHubMixin):
parts = [f"name='{self.name}'", steps_repr] parts = [f"name='{self.name}'", steps_repr]
if self.seed is not None:
parts.append(f"seed={self.seed}")
return f"RobotProcessor({', '.join(parts)})" return f"RobotProcessor({', '.join(parts)})"
def __post_init__(self): def __post_init__(self):
@@ -931,6 +863,9 @@ class ObservationProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION) observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
processed_observation = self.observation(observation) processed_observation = self.observation(observation)
# Create a new transition dict with the processed observation # Create a new transition dict with the processed observation
new_transition = transition.copy() new_transition = transition.copy()
@@ -988,6 +923,9 @@ class ActionProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION) action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
processed_action = self.action(action) processed_action = self.action(action)
# Create a new transition dict with the processed action # Create a new transition dict with the processed action
new_transition = transition.copy() new_transition = transition.copy()
@@ -1044,6 +982,9 @@ class RewardProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
reward = transition.get(TransitionKey.REWARD) reward = transition.get(TransitionKey.REWARD)
if reward is None:
return transition
processed_reward = self.reward(reward) processed_reward = self.reward(reward)
# Create a new transition dict with the processed reward # Create a new transition dict with the processed reward
new_transition = transition.copy() new_transition = transition.copy()
@@ -1105,6 +1046,9 @@ class DoneProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
done = transition.get(TransitionKey.DONE) done = transition.get(TransitionKey.DONE)
if done is None:
return transition
processed_done = self.done(done) processed_done = self.done(done)
# Create a new transition dict with the processed done flag # Create a new transition dict with the processed done flag
new_transition = transition.copy() new_transition = transition.copy()
@@ -1162,6 +1106,9 @@ class TruncatedProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
truncated = transition.get(TransitionKey.TRUNCATED) truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is None:
return transition
processed_truncated = self.truncated(truncated) processed_truncated = self.truncated(truncated)
# Create a new transition dict with the processed truncated flag # Create a new transition dict with the processed truncated flag
new_transition = transition.copy() new_transition = transition.copy()
@@ -1224,6 +1171,9 @@ class InfoProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
info = transition.get(TransitionKey.INFO) info = transition.get(TransitionKey.INFO)
if info is None:
return transition
processed_info = self.info(info) processed_info = self.info(info)
# Create a new transition dict with the processed info # Create a new transition dict with the processed info
new_transition = transition.copy() new_transition = transition.copy()
@@ -1267,6 +1217,9 @@ class ComplementaryDataProcessor:
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return transition
processed_complementary_data = self.complementary_data(complementary_data) processed_complementary_data = self.complementary_data(complementary_data)
# Create a new transition dict with the processed complementary data # Create a new transition dict with the processed complementary data
new_transition = transition.copy() new_transition = transition.copy()
+7 -19
View File
@@ -16,24 +16,21 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
import torch
from lerobot.configs.types import PolicyFeature from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey from lerobot.processor.pipeline import (
ObservationProcessor,
ProcessorStepRegistry,
)
@dataclass @dataclass
@ProcessorStepRegistry.register(name="rename_processor") @ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor: class RenameProcessor(ObservationProcessor):
"""Rename processor that renames keys in the observation.""" """Rename processor that renames keys in the observation."""
rename_map: dict[str, str] = field(default_factory=dict) rename_map: dict[str, str] = field(default_factory=dict)
def __call__(self, transition: EnvTransition) -> EnvTransition: def observation(self, observation):
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
return transition
processed_obs = {} processed_obs = {}
for key, value in observation.items(): for key, value in observation.items():
if key in self.rename_map: if key in self.rename_map:
@@ -41,20 +38,11 @@ class RenameProcessor:
else: else:
processed_obs[key] = value processed_obs[key] = value
# Create a new transition with the renamed observation return processed_obs
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
return new_transition
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
return {"rename_map": self.rename_map} return {"rename_map": self.rename_map}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms: """Transforms:
- Each key in the observation that appears in `rename_map` is renamed to its value. - Each key in the observation that appears in `rename_map` is renamed to its value.
+19 -34
View File
@@ -20,11 +20,7 @@ import torch
from lerobot.configs.types import FeatureType from lerobot.configs.types import FeatureType
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor import ( from lerobot.processor import VanillaObservationProcessor
ImageProcessor,
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor.pipeline import TransitionKey from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed from tests.conftest import assert_contract_is_typed
@@ -46,7 +42,7 @@ def create_transition(
def test_process_single_image(): def test_process_single_image():
"""Test processing a single image.""" """Test processing a single image."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
# Create a mock image (H, W, C) format, uint8 # Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -72,7 +68,7 @@ def test_process_single_image():
def test_process_image_dict(): def test_process_image_dict():
"""Test processing multiple images in a dictionary.""" """Test processing multiple images in a dictionary."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
# Create mock images # Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -95,7 +91,7 @@ def test_process_image_dict():
def test_process_batched_image(): def test_process_batched_image():
"""Test processing already batched images.""" """Test processing already batched images."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
# Create a batched image (B, H, W, C) # Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
@@ -112,7 +108,7 @@ def test_process_batched_image():
def test_invalid_image_format(): def test_invalid_image_format():
"""Test error handling for invalid image formats.""" """Test error handling for invalid image formats."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
# Test wrong channel order (channels first) # Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
@@ -125,7 +121,7 @@ def test_invalid_image_format():
def test_invalid_image_dtype(): def test_invalid_image_dtype():
"""Test error handling for invalid image dtype.""" """Test error handling for invalid image dtype."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
# Test wrong dtype # Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32) image = np.random.rand(64, 64, 3).astype(np.float32)
@@ -138,7 +134,7 @@ def test_invalid_image_dtype():
def test_no_pixels_in_observation(): def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation.""" """Test processor when no pixels are in observation."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])} observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation) transition = create_transition(observation=observation)
@@ -153,7 +149,7 @@ def test_no_pixels_in_observation():
def test_none_observation(): def test_none_observation():
"""Test processor with None observation.""" """Test processor with None observation."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
transition = create_transition() transition = create_transition()
result = processor(transition) result = processor(transition)
@@ -163,7 +159,7 @@ def test_none_observation():
def test_serialization_methods(): def test_serialization_methods():
"""Test serialization methods.""" """Test serialization methods."""
processor = ImageProcessor() processor = VanillaObservationProcessor()
# Test get_config # Test get_config
config = processor.get_config() config = processor.get_config()
@@ -182,7 +178,7 @@ def test_serialization_methods():
def test_process_environment_state(): def test_process_environment_state():
"""Test processing environment_state.""" """Test processing environment_state."""
processor = StateProcessor() processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state} observation = {"environment_state": env_state}
@@ -203,7 +199,7 @@ def test_process_environment_state():
def test_process_agent_pos(): def test_process_agent_pos():
"""Test processing agent_pos.""" """Test processing agent_pos."""
processor = StateProcessor() processor = VanillaObservationProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos} observation = {"agent_pos": agent_pos}
@@ -224,7 +220,7 @@ def test_process_agent_pos():
def test_process_batched_states(): def test_process_batched_states():
"""Test processing already batched states.""" """Test processing already batched states."""
processor = StateProcessor() processor = VanillaObservationProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
@@ -242,7 +238,7 @@ def test_process_batched_states():
def test_process_both_states(): def test_process_both_states():
"""Test processing both environment_state and agent_pos.""" """Test processing both environment_state and agent_pos."""
processor = StateProcessor() processor = VanillaObservationProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32) env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32) agent_pos = np.array([0.5, -0.5], dtype=np.float32)
@@ -267,7 +263,7 @@ def test_process_both_states():
def test_no_states_in_observation(): def test_no_states_in_observation():
"""Test processor when no states are in observation.""" """Test processor when no states are in observation."""
processor = StateProcessor() processor = VanillaObservationProcessor()
observation = {"other_data": np.array([1, 2, 3])} observation = {"other_data": np.array([1, 2, 3])}
transition = create_transition(observation=observation) transition = create_transition(observation=observation)
@@ -359,17 +355,6 @@ def test_empty_observation():
assert processed_obs == {} assert processed_obs == {}
def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
def test_equivalent_to_original_function(): def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation.""" """Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison # Import the original function for comparison
@@ -426,7 +411,7 @@ def test_equivalent_with_image_dict():
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
processor = ImageProcessor() processor = VanillaObservationProcessor()
features = { features = {
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)), "keep": policy_feature_factory(FeatureType.ENV, (1,)),
@@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
processor = ImageProcessor() processor = VanillaObservationProcessor()
features = { features = {
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"keep": policy_feature_factory(FeatureType.ENV, (1,)), "keep": policy_feature_factory(FeatureType.ENV, (1,)),
@@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
processor = ImageProcessor() processor = VanillaObservationProcessor()
features = { features = {
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
@@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
processor = StateProcessor() processor = VanillaObservationProcessor()
features = { features = {
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)), "environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
@@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
proc = StateProcessor() proc = VanillaObservationProcessor()
features = { features = {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
+8 -72
View File
@@ -363,32 +363,6 @@ def test_hooks():
assert after_calls == [0] assert after_calls == [0]
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
reset_called = []
def reset_hook():
reset_called.append(True)
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = create_transition()
pipeline(transition)
pipeline(transition)
assert step.counter == 2
# Reset should reset step and call hook
pipeline.reset()
assert step.counter == 0
assert len(reset_called) == 1
def test_unregister_hooks(): def test_unregister_hooks():
"""Test unregistering hooks from the pipeline.""" """Test unregistering hooks from the pipeline."""
step = MockStep("test_step") step = MockStep("test_step")
@@ -428,21 +402,6 @@ def test_unregister_hooks():
pipeline(transition) pipeline(transition)
assert len(after_calls) == 0 assert len(after_calls) == 0
# Test reset_hook
reset_calls = []
def reset_hook():
reset_calls.append(True)
pipeline.register_reset_hook(reset_hook)
pipeline.reset()
assert len(reset_calls) == 1
pipeline.unregister_reset_hook(reset_hook)
reset_calls.clear()
pipeline.reset()
assert len(reset_calls) == 0
def test_unregister_nonexistent_hook(): def test_unregister_nonexistent_hook():
"""Test error handling when unregistering hooks that don't exist.""" """Test error handling when unregistering hooks that don't exist."""
@@ -461,9 +420,6 @@ def test_unregister_nonexistent_hook():
with pytest.raises(ValueError, match="not found in after_step_hooks"): with pytest.raises(ValueError, match="not found in after_step_hooks"):
pipeline.unregister_after_step_hook(some_hook) pipeline.unregister_after_step_hook(some_hook)
with pytest.raises(ValueError, match="not found in reset_hooks"):
pipeline.unregister_reset_hook(reset_hook)
def test_multiple_hooks_and_selective_unregister(): def test_multiple_hooks_and_selective_unregister():
"""Test registering multiple hooks and selectively unregistering them.""" """Test registering multiple hooks and selectively unregistering them."""
@@ -552,22 +508,6 @@ def test_hook_execution_order_documentation():
assert execution_order == ["A", "C", "B"] # B is now last assert execution_order == ["A", "C", "B"] # B is now last
def test_profile_steps():
"""Test step profiling functionality."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = create_transition()
profile_results = pipeline.profile_steps(transition, num_runs=10)
assert len(profile_results) == 2
assert "step_0_MockStep" in profile_results
assert "step_1_MockStep" in profile_results
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
def test_save_and_load_pretrained(): def test_save_and_load_pretrained():
"""Test saving and loading pipeline. """Test saving and loading pipeline.
@@ -581,7 +521,7 @@ def test_save_and_load_pretrained():
step1.counter = 5 step1.counter = 5
step2.counter = 10 step2.counter = 10
pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42) pipeline = RobotProcessor([step1, step2], name="TestPipeline")
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline # Save pipeline
@@ -596,7 +536,6 @@ def test_save_and_load_pretrained():
config = json.load(f) config = json.load(f)
assert config["name"] == "TestPipeline" assert config["name"] == "TestPipeline"
assert config["seed"] == 42
assert len(config["steps"]) == 2 assert len(config["steps"]) == 2
# Verify counters are saved in config, not in separate state files # Verify counters are saved in config, not in separate state files
@@ -607,7 +546,6 @@ def test_save_and_load_pretrained():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestPipeline" assert loaded_pipeline.name == "TestPipeline"
assert loaded_pipeline.seed == 42
assert len(loaded_pipeline) == 2 assert len(loaded_pipeline) == 2
# Check that counter was restored from config # Check that counter was restored from config
@@ -1255,10 +1193,10 @@ def test_repr_with_custom_name():
def test_repr_with_seed(): def test_repr_with_seed():
"""Test __repr__ with seed parameter.""" """Test __repr__ with seed parameter."""
step = MockStep("test_step") step = MockStep("test_step")
pipeline = RobotProcessor([step], seed=42) pipeline = RobotProcessor([step])
repr_str = repr(pipeline) repr_str = repr(pipeline)
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep], seed=42)" expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])"
assert repr_str == expected assert repr_str == expected
@@ -1266,19 +1204,17 @@ def test_repr_with_custom_name_and_seed():
"""Test __repr__ with both custom name and seed.""" """Test __repr__ with both custom name and seed."""
step1 = MockStep("step1") step1 = MockStep("step1")
step2 = MockStepWithoutOptionalMethods() step2 = MockStepWithoutOptionalMethods()
pipeline = RobotProcessor([step1, step2], name="MyProcessor", seed=123) pipeline = RobotProcessor([step1, step2], name="MyProcessor")
repr_str = repr(pipeline) repr_str = repr(pipeline)
expected = ( expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])"
"RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods], seed=123)"
)
assert repr_str == expected assert repr_str == expected
def test_repr_without_seed(): def test_repr_without_seed():
"""Test __repr__ when seed is explicitly None (should not show seed).""" """Test __repr__ when seed is explicitly None (should not show seed)."""
step = MockStep("test_step") step = MockStep("test_step")
pipeline = RobotProcessor([step], name="TestProcessor", seed=None) pipeline = RobotProcessor([step], name="TestProcessor")
repr_str = repr(pipeline) repr_str = repr(pipeline)
expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])"
@@ -1306,10 +1242,10 @@ def test_repr_edge_case_long_names():
step3 = MockStepWithTensorState() step3 = MockStepWithTensorState()
step4 = MockNonModuleStepWithState() step4 = MockNonModuleStepWithState()
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames", seed=999) pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames")
repr_str = repr(pipeline) repr_str = repr(pipeline)
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)" expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])"
assert repr_str == expected assert repr_str == expected