mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-29 15:39:56 +00:00
refactor(pipeline): minor improvements (#1684)
* chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes
This commit is contained in:
@@ -16,11 +16,7 @@
|
||||
|
||||
from .device_processor import DeviceProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .observation_processor import (
|
||||
ImageProcessor,
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
DoneProcessor,
|
||||
@@ -43,7 +39,6 @@ __all__ = [
|
||||
"DoneProcessor",
|
||||
"EnvTransition",
|
||||
"IdentityProcessor",
|
||||
"ImageProcessor",
|
||||
"InfoProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
@@ -53,7 +48,6 @@ __all__ = [
|
||||
"RenameProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"StateProcessor",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,10 +31,11 @@ class DeviceProcessor:
|
||||
specified device (CPU or GPU) before they are returned.
|
||||
"""
|
||||
|
||||
device: str = "cpu"
|
||||
device: torch.device = "cpu"
|
||||
|
||||
def __post_init__(self):
|
||||
self.non_blocking = "cuda" in self.device
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
|
||||
@@ -220,7 +220,6 @@ class UnnormalizerProcessor:
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
eps: float = 1e-8
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@@ -230,10 +229,8 @@ class UnnormalizerProcessor:
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
eps: float = 1e-8,
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
@@ -308,7 +305,6 @@ class UnnormalizerProcessor:
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"eps": self.eps,
|
||||
"features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||
},
|
||||
|
||||
@@ -13,8 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -23,52 +22,27 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageProcessor:
|
||||
"""Process image observations from environment format to policy format.
|
||||
|
||||
Converts images from:
|
||||
- Channel-last (H, W, C) to channel-first (C, H, W)
|
||||
- uint8 [0, 255] to float32 [0, 1]
|
||||
- Adds batch dimension if needed
|
||||
- Handles both single images and dictionaries of images
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor(ObservationProcessor):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
Image processing:
|
||||
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
|
||||
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
|
||||
- Adds a batch dimension if missing
|
||||
- Supports single images and image dictionaries
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
processed_obs = {}
|
||||
|
||||
# Copy all observations first
|
||||
for key, value in observation.items():
|
||||
processed_obs[key] = value
|
||||
|
||||
# Handle pixels key if present
|
||||
pixels = observation.get("pixels")
|
||||
if pixels is not None:
|
||||
# Remove pixels from processed_obs since we'll replace it with processed images
|
||||
processed_obs.pop("pixels", None)
|
||||
# Determine image mapping
|
||||
if isinstance(pixels, dict):
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||
else:
|
||||
imgs = {OBS_IMAGE: pixels}
|
||||
|
||||
# Process each image
|
||||
for imgkey, img in imgs.items():
|
||||
processed_img = self._process_single_image(img)
|
||||
processed_obs[imgkey] = processed_img
|
||||
|
||||
# Return new transition with processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
State processing:
|
||||
- Maps 'environment_state' to observation.environment_state
|
||||
- Maps 'agent_pos' to observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds a batch dimension if missing
|
||||
"""
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
@@ -95,173 +69,89 @@ class ImageProcessor:
|
||||
|
||||
return img_tensor
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
pixels -> OBS_IMAGE,
|
||||
observation.pixels -> OBS_IMAGE,
|
||||
pixels.<cam> -> OBS_IMAGES.<cam>,
|
||||
observation.pixels.<cam> -> OBS_IMAGES.<cam>
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
if "pixels" in features:
|
||||
features[OBS_IMAGE] = features.pop("pixels")
|
||||
if "observation.pixels" in features:
|
||||
features[OBS_IMAGE] = features.pop("observation.pixels")
|
||||
|
||||
prefixes = ("pixels.", "observation.pixels.")
|
||||
for key in list(features.keys()):
|
||||
for p in prefixes:
|
||||
if key.startswith(p):
|
||||
suffix = key[len(p) :]
|
||||
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
|
||||
break
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateProcessor:
|
||||
"""Process state observations from environment format to policy format.
|
||||
|
||||
Handles:
|
||||
- environment_state -> observation.environment_state
|
||||
- agent_pos -> observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds batch dimension if needed
|
||||
Processes both image and state observations.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
processed_obs = observation.copy()
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
if "pixels" in processed_obs:
|
||||
pixels = processed_obs.pop("pixels")
|
||||
|
||||
processed_obs = dict(observation) # Copy existing observations
|
||||
if isinstance(pixels, dict):
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||
else:
|
||||
imgs = {OBS_IMAGE: pixels}
|
||||
|
||||
# Process environment_state
|
||||
if "environment_state" in observation:
|
||||
env_state = torch.from_numpy(observation["environment_state"]).float()
|
||||
for imgkey, img in imgs.items():
|
||||
processed_obs[imgkey] = self._process_single_image(img)
|
||||
|
||||
if "environment_state" in processed_obs:
|
||||
env_state_np = processed_obs.pop("environment_state")
|
||||
env_state = torch.from_numpy(env_state_np).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
processed_obs[OBS_ENV_STATE] = env_state
|
||||
# Remove original key
|
||||
del processed_obs["environment_state"]
|
||||
|
||||
# Process agent_pos
|
||||
if "agent_pos" in observation:
|
||||
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
|
||||
if "agent_pos" in processed_obs:
|
||||
agent_pos_np = processed_obs.pop("agent_pos")
|
||||
agent_pos = torch.from_numpy(agent_pos_np).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
processed_obs[OBS_STATE] = agent_pos
|
||||
# Remove original key
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
# Return new transition with processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
return processed_obs
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
environment_state -> OBS_ENV_STATE,
|
||||
agent_pos -> OBS_STATE,
|
||||
observation.environment_state -> OBS_ENV_STATE,
|
||||
observation.agent_pos -> OBS_STATE
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- environment_state -> OBS_ENV_STATE,
|
||||
- agent_pos -> OBS_STATE,
|
||||
- observation.environment_state -> OBS_ENV_STATE,
|
||||
- observation.agent_pos -> OBS_STATE
|
||||
"""
|
||||
pairs = (
|
||||
("environment_state", OBS_ENV_STATE),
|
||||
("agent_pos", OBS_STATE),
|
||||
)
|
||||
for old, new in pairs:
|
||||
if old in features:
|
||||
features[new] = features.pop(old)
|
||||
prefixed = f"observation.{old}"
|
||||
if prefixed in features:
|
||||
features[new] = features.pop(prefixed)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor:
|
||||
"""Complete observation processor that combines image and state processing.
|
||||
|
||||
This processor replicates the functionality of the original preprocess_observation
|
||||
function but in a modular, composable way that fits into the pipeline architecture.
|
||||
"""
|
||||
|
||||
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
|
||||
state_processor: StateProcessor = field(default_factory=StateProcessor)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# First process images
|
||||
transition = self.image_processor(transition)
|
||||
# Then process state
|
||||
transition = self.state_processor(transition)
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {
|
||||
"image_processor": self.image_processor.get_config(),
|
||||
"state_processor": self.state_processor.get_config(),
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary."""
|
||||
state = {}
|
||||
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
|
||||
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary."""
|
||||
image_state = {
|
||||
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
|
||||
}
|
||||
state_state = {
|
||||
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
|
||||
}
|
||||
|
||||
self.image_processor.load_state_dict(image_state)
|
||||
self.state_processor.load_state_dict(state_state)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state."""
|
||||
self.image_processor.reset()
|
||||
self.state_processor.reset()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features = self.image_processor.feature_contract(features)
|
||||
features = self.state_processor.feature_contract(features)
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"agent_pos": OBS_STATE,
|
||||
}
|
||||
|
||||
prefix_pairs = {
|
||||
"pixels.": f"{OBS_IMAGES}.",
|
||||
}
|
||||
|
||||
for key in list(features.keys()):
|
||||
matched_prefix = False
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if matched_prefix:
|
||||
continue
|
||||
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
if key in features:
|
||||
features[new] = features.pop(key)
|
||||
break
|
||||
|
||||
return features
|
||||
|
||||
@@ -36,6 +36,7 @@ from lerobot.configs.types import PolicyFeature
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
# TODO(Steven): Use consts
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
@@ -45,19 +46,18 @@ class TransitionKey(str, Enum):
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
class EnvTransition(TypedDict, total=False):
|
||||
"""Environment transition data structure.
|
||||
|
||||
All fields are optional (total=False) to allow flexible usage.
|
||||
"""
|
||||
|
||||
observation: dict[str, Any] | None
|
||||
action: Any | torch.Tensor | None
|
||||
reward: float | torch.Tensor | None
|
||||
done: bool | torch.Tensor | None
|
||||
truncated: bool | torch.Tensor | None
|
||||
info: dict[str, Any] | None
|
||||
complementary_data: dict[str, Any] | None
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ProcessorStepRegistry:
|
||||
@@ -135,8 +135,8 @@ class ProcessorStepRegistry:
|
||||
class ProcessorStep(Protocol):
|
||||
"""Structural typing interface for a single processor step.
|
||||
|
||||
A step is any callable accepting a full `EnvTransition` tuple and
|
||||
returning a (possibly modified) tuple of the same structure. Implementers
|
||||
A step is any callable accepting a full `EnvTransition` dict and
|
||||
returning a (possibly modified) dict of the same structure. Implementers
|
||||
are encouraged—but not required—to expose the optional helper methods
|
||||
listed below. When present, these hooks let `RobotProcessor`
|
||||
automatically serialise the step's configuration and learnable state using
|
||||
@@ -254,24 +254,22 @@ class RobotProcessor(ModelHubMixin):
|
||||
Composable, debuggable post-processing processor for robot transitions.
|
||||
|
||||
The class orchestrates an ordered collection of small, functional transforms—steps—executed
|
||||
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples
|
||||
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
|
||||
and batch dictionaries, automatically converting between formats as needed.
|
||||
|
||||
Args:
|
||||
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
||||
name: Human-readable identifier that is persisted inside the JSON config.
|
||||
Defaults to "RobotProcessor".
|
||||
seed: Global seed forwarded to steps that choose to consume it. Defaults to None.
|
||||
to_transition: Function to convert batch dict to EnvTransition tuple.
|
||||
to_transition: Function to convert batch dict to EnvTransition dict.
|
||||
Defaults to _default_batch_to_transition.
|
||||
to_output: Function to convert EnvTransition tuple to the desired output format.
|
||||
Usually it is a batch dict or EnvTransition tuple.
|
||||
to_output: Function to convert EnvTransition dict to the desired output format.
|
||||
Usually it is a batch dict or EnvTransition dict.
|
||||
Defaults to _default_transition_to_batch.
|
||||
before_step_hooks: List of hooks called before each step. Each hook receives the step
|
||||
index and transition, and can optionally return a modified transition.
|
||||
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
||||
index and transition, and can optionally return a modified transition.
|
||||
reset_hooks: List of hooks called during processor reset.
|
||||
|
||||
Hook Semantics:
|
||||
- Hooks are executed sequentially in the order they were registered. There is no way to
|
||||
@@ -290,7 +288,6 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||
name: str = "RobotProcessor"
|
||||
seed: int | None = None
|
||||
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||||
default_factory=lambda: _default_batch_to_transition, repr=False
|
||||
@@ -303,7 +300,6 @@ class RobotProcessor(ModelHubMixin):
|
||||
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||
|
||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||
"""Process data through all steps.
|
||||
@@ -431,7 +427,6 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"seed": self.seed,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
@@ -728,7 +723,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
f"Make sure override keys match exact step class names or registry names."
|
||||
)
|
||||
|
||||
return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed"))
|
||||
return cls(steps, loaded_config.get("name", "RobotProcessor"))
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of steps in the processor."""
|
||||
@@ -740,7 +735,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
||||
"""
|
||||
if isinstance(idx, slice):
|
||||
return RobotProcessor(self.steps[idx], self.name, self.seed)
|
||||
return RobotProcessor(self.steps[idx], self.name)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||
@@ -783,71 +778,11 @@ class RobotProcessor(ModelHubMixin):
|
||||
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
|
||||
) from None
|
||||
|
||||
def register_reset_hook(self, fn: Callable[[], None]):
|
||||
"""Attach fn to be executed when reset is called."""
|
||||
self.reset_hooks.append(fn)
|
||||
|
||||
def unregister_reset_hook(self, fn: Callable[[], None]):
|
||||
"""Remove a previously registered reset hook.
|
||||
|
||||
Args:
|
||||
fn: The exact function reference that was registered. Must be the same object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the hook is not found in the registered hooks.
|
||||
"""
|
||||
try:
|
||||
self.reset_hooks.remove(fn)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference."
|
||||
) from None
|
||||
|
||||
def reset(self):
|
||||
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
|
||||
for step in self.steps:
|
||||
if hasattr(step, "reset"):
|
||||
step.reset() # type: ignore[attr-defined]
|
||||
for fn in self.reset_hooks:
|
||||
fn()
|
||||
|
||||
def profile_steps(
|
||||
self, transition: EnvTransition, num_runs: int = 100, warmup_runs: int = 5
|
||||
) -> dict[str, float]:
|
||||
"""Profile the execution time of each step for performance optimization."""
|
||||
import copy
|
||||
import time
|
||||
|
||||
profile_results = {}
|
||||
|
||||
# Make a copy to avoid altering the original transition
|
||||
transition_copy = copy.deepcopy(transition)
|
||||
|
||||
# Get intermediate transitions for each step using step_through
|
||||
intermediate_transitions = list(self.step_through(transition_copy))
|
||||
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
step_name = f"step_{idx}_{processor_step.__class__.__name__}"
|
||||
|
||||
# Use the appropriate input transition for this step
|
||||
input_transition = intermediate_transitions[idx]
|
||||
|
||||
# Warm up - copy transition for each run to ensure consistent conditions
|
||||
for _ in range(warmup_runs):
|
||||
transition_copy = copy.deepcopy(input_transition)
|
||||
_ = processor_step(transition_copy)
|
||||
|
||||
# Time the step - copy transition for each run to ensure consistent conditions
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_runs):
|
||||
transition_copy = copy.deepcopy(input_transition)
|
||||
_ = processor_step(transition_copy)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
|
||||
profile_results[step_name] = avg_time
|
||||
|
||||
return profile_results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a readable string representation of the processor."""
|
||||
@@ -864,9 +799,6 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
parts = [f"name='{self.name}'", steps_repr]
|
||||
|
||||
if self.seed is not None:
|
||||
parts.append(f"seed={self.seed}")
|
||||
|
||||
return f"RobotProcessor({', '.join(parts)})"
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -931,6 +863,9 @@ class ObservationProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
processed_observation = self.observation(observation)
|
||||
# Create a new transition dict with the processed observation
|
||||
new_transition = transition.copy()
|
||||
@@ -988,6 +923,9 @@ class ActionProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
|
||||
processed_action = self.action(action)
|
||||
# Create a new transition dict with the processed action
|
||||
new_transition = transition.copy()
|
||||
@@ -1044,6 +982,9 @@ class RewardProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is None:
|
||||
return transition
|
||||
|
||||
processed_reward = self.reward(reward)
|
||||
# Create a new transition dict with the processed reward
|
||||
new_transition = transition.copy()
|
||||
@@ -1105,6 +1046,9 @@ class DoneProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is None:
|
||||
return transition
|
||||
|
||||
processed_done = self.done(done)
|
||||
# Create a new transition dict with the processed done flag
|
||||
new_transition = transition.copy()
|
||||
@@ -1162,6 +1106,9 @@ class TruncatedProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is None:
|
||||
return transition
|
||||
|
||||
processed_truncated = self.truncated(truncated)
|
||||
# Create a new transition dict with the processed truncated flag
|
||||
new_transition = transition.copy()
|
||||
@@ -1224,6 +1171,9 @@ class InfoProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO)
|
||||
if info is None:
|
||||
return transition
|
||||
|
||||
processed_info = self.info(info)
|
||||
# Create a new transition dict with the processed info
|
||||
new_transition = transition.copy()
|
||||
@@ -1267,6 +1217,9 @@ class ComplementaryDataProcessor:
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return transition
|
||||
|
||||
processed_complementary_data = self.complementary_data(complementary_data)
|
||||
# Create a new transition dict with the processed complementary data
|
||||
new_transition = transition.copy()
|
||||
|
||||
@@ -16,24 +16,21 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rename_processor")
|
||||
class RenameProcessor:
|
||||
class RenameProcessor(ObservationProcessor):
|
||||
"""Rename processor that renames keys in the observation."""
|
||||
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
def observation(self, observation):
|
||||
processed_obs = {}
|
||||
for key, value in observation.items():
|
||||
if key in self.rename_map:
|
||||
@@ -41,20 +38,11 @@ class RenameProcessor:
|
||||
else:
|
||||
processed_obs[key] = value
|
||||
|
||||
# Create a new transition with the renamed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
return processed_obs
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
|
||||
@@ -20,11 +20,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import (
|
||||
ImageProcessor,
|
||||
StateProcessor,
|
||||
VanillaObservationProcessor,
|
||||
)
|
||||
from lerobot.processor import VanillaObservationProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
@@ -46,7 +42,7 @@ def create_transition(
|
||||
|
||||
def test_process_single_image():
|
||||
"""Test processing a single image."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create a mock image (H, W, C) format, uint8
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
@@ -72,7 +68,7 @@ def test_process_single_image():
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create mock images
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -95,7 +91,7 @@ def test_process_image_dict():
|
||||
|
||||
def test_process_batched_image():
|
||||
"""Test processing already batched images."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create a batched image (B, H, W, C)
|
||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||
@@ -112,7 +108,7 @@ def test_process_batched_image():
|
||||
|
||||
def test_invalid_image_format():
|
||||
"""Test error handling for invalid image formats."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test wrong channel order (channels first)
|
||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||
@@ -125,7 +121,7 @@ def test_invalid_image_format():
|
||||
|
||||
def test_invalid_image_dtype():
|
||||
"""Test error handling for invalid image dtype."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test wrong dtype
|
||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||
@@ -138,7 +134,7 @@ def test_invalid_image_dtype():
|
||||
|
||||
def test_no_pixels_in_observation():
|
||||
"""Test processor when no pixels are in observation."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -153,7 +149,7 @@ def test_no_pixels_in_observation():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
@@ -163,7 +159,7 @@ def test_none_observation():
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
@@ -182,7 +178,7 @@ def test_serialization_methods():
|
||||
|
||||
def test_process_environment_state():
|
||||
"""Test processing environment_state."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
@@ -203,7 +199,7 @@ def test_process_environment_state():
|
||||
|
||||
def test_process_agent_pos():
|
||||
"""Test processing agent_pos."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
@@ -224,7 +220,7 @@ def test_process_agent_pos():
|
||||
|
||||
def test_process_batched_states():
|
||||
"""Test processing already batched states."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
@@ -242,7 +238,7 @@ def test_process_batched_states():
|
||||
|
||||
def test_process_both_states():
|
||||
"""Test processing both environment_state and agent_pos."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
@@ -267,7 +263,7 @@ def test_process_both_states():
|
||||
|
||||
def test_no_states_in_observation():
|
||||
"""Test processor when no states are in observation."""
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -359,17 +355,6 @@ def test_empty_observation():
|
||||
assert processed_obs == {}
|
||||
|
||||
|
||||
def test_custom_sub_processors():
|
||||
"""Test ObservationProcessor with custom sub-processors."""
|
||||
image_proc = ImageProcessor()
|
||||
state_proc = StateProcessor()
|
||||
processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc)
|
||||
|
||||
# Should use the provided processors
|
||||
assert processor.image_processor is image_proc
|
||||
assert processor.state_processor is state_proc
|
||||
|
||||
|
||||
def test_equivalent_to_original_function():
|
||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
||||
# Import the original function for comparison
|
||||
@@ -426,7 +411,7 @@ def test_equivalent_with_image_dict():
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
@@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
@@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
||||
processor = ImageProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
@@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
||||
processor = StateProcessor()
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
@@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
||||
proc = StateProcessor()
|
||||
proc = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
|
||||
@@ -363,32 +363,6 @@ def test_hooks():
|
||||
assert after_calls == [0]
|
||||
|
||||
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
reset_called = []
|
||||
|
||||
def reset_hook():
|
||||
reset_called.append(True)
|
||||
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
|
||||
# Make some calls to increment counter
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
pipeline(transition)
|
||||
|
||||
assert step.counter == 2
|
||||
|
||||
# Reset should reset step and call hook
|
||||
pipeline.reset()
|
||||
|
||||
assert step.counter == 0
|
||||
assert len(reset_called) == 1
|
||||
|
||||
|
||||
def test_unregister_hooks():
|
||||
"""Test unregistering hooks from the pipeline."""
|
||||
step = MockStep("test_step")
|
||||
@@ -428,21 +402,6 @@ def test_unregister_hooks():
|
||||
pipeline(transition)
|
||||
assert len(after_calls) == 0
|
||||
|
||||
# Test reset_hook
|
||||
reset_calls = []
|
||||
|
||||
def reset_hook():
|
||||
reset_calls.append(True)
|
||||
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
pipeline.reset()
|
||||
assert len(reset_calls) == 1
|
||||
|
||||
pipeline.unregister_reset_hook(reset_hook)
|
||||
reset_calls.clear()
|
||||
pipeline.reset()
|
||||
assert len(reset_calls) == 0
|
||||
|
||||
|
||||
def test_unregister_nonexistent_hook():
|
||||
"""Test error handling when unregistering hooks that don't exist."""
|
||||
@@ -461,9 +420,6 @@ def test_unregister_nonexistent_hook():
|
||||
with pytest.raises(ValueError, match="not found in after_step_hooks"):
|
||||
pipeline.unregister_after_step_hook(some_hook)
|
||||
|
||||
with pytest.raises(ValueError, match="not found in reset_hooks"):
|
||||
pipeline.unregister_reset_hook(reset_hook)
|
||||
|
||||
|
||||
def test_multiple_hooks_and_selective_unregister():
|
||||
"""Test registering multiple hooks and selectively unregistering them."""
|
||||
@@ -552,22 +508,6 @@ def test_hook_execution_order_documentation():
|
||||
assert execution_order == ["A", "C", "B"] # B is now last
|
||||
|
||||
|
||||
def test_profile_steps():
|
||||
"""Test step profiling functionality."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
transition = create_transition()
|
||||
|
||||
profile_results = pipeline.profile_steps(transition, num_runs=10)
|
||||
|
||||
assert len(profile_results) == 2
|
||||
assert "step_0_MockStep" in profile_results
|
||||
assert "step_1_MockStep" in profile_results
|
||||
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading pipeline.
|
||||
|
||||
@@ -581,7 +521,7 @@ def test_save_and_load_pretrained():
|
||||
step1.counter = 5
|
||||
step2.counter = 10
|
||||
|
||||
pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42)
|
||||
pipeline = RobotProcessor([step1, step2], name="TestPipeline")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
@@ -596,7 +536,6 @@ def test_save_and_load_pretrained():
|
||||
config = json.load(f)
|
||||
|
||||
assert config["name"] == "TestPipeline"
|
||||
assert config["seed"] == 42
|
||||
assert len(config["steps"]) == 2
|
||||
|
||||
# Verify counters are saved in config, not in separate state files
|
||||
@@ -607,7 +546,6 @@ def test_save_and_load_pretrained():
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert loaded_pipeline.name == "TestPipeline"
|
||||
assert loaded_pipeline.seed == 42
|
||||
assert len(loaded_pipeline) == 2
|
||||
|
||||
# Check that counter was restored from config
|
||||
@@ -1255,10 +1193,10 @@ def test_repr_with_custom_name():
|
||||
def test_repr_with_seed():
|
||||
"""Test __repr__ with seed parameter."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step], seed=42)
|
||||
pipeline = RobotProcessor([step])
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep], seed=42)"
|
||||
expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
@@ -1266,19 +1204,17 @@ def test_repr_with_custom_name_and_seed():
|
||||
"""Test __repr__ with both custom name and seed."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStepWithoutOptionalMethods()
|
||||
pipeline = RobotProcessor([step1, step2], name="MyProcessor", seed=123)
|
||||
pipeline = RobotProcessor([step1, step2], name="MyProcessor")
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = (
|
||||
"RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods], seed=123)"
|
||||
)
|
||||
expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
def test_repr_without_seed():
|
||||
"""Test __repr__ when seed is explicitly None (should not show seed)."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step], name="TestProcessor", seed=None)
|
||||
pipeline = RobotProcessor([step], name="TestProcessor")
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])"
|
||||
@@ -1306,10 +1242,10 @@ def test_repr_edge_case_long_names():
|
||||
step3 = MockStepWithTensorState()
|
||||
step4 = MockNonModuleStepWithState()
|
||||
|
||||
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames", seed=999)
|
||||
pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames")
|
||||
repr_str = repr(pipeline)
|
||||
|
||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)"
|
||||
expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])"
|
||||
assert repr_str == expected
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user