From fd4ae3466bf2279074608712052da0cefbdabaf9 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 6 Aug 2025 14:00:13 +0200 Subject: [PATCH] refactor(pipeline): minor improvements (#1684) * chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes --- src/lerobot/processor/__init__.py | 8 +- src/lerobot/processor/device_processor.py | 6 +- src/lerobot/processor/normalize_processor.py | 6 +- .../processor/observation_processor.py | 276 ++++++------------ src/lerobot/processor/pipeline.py | 131 +++------ src/lerobot/processor/rename_processor.py | 26 +- tests/processor/test_observation_processor.py | 53 ++-- tests/processor/test_pipeline.py | 80 +---- 8 files changed, 165 insertions(+), 421 deletions(-) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 0a5a5dd2c..8dd244c27 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -16,11 +16,7 @@ from .device_processor import DeviceProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor -from .observation_processor import ( - ImageProcessor, - StateProcessor, - VanillaObservationProcessor, -) +from .observation_processor import VanillaObservationProcessor from .pipeline import ( ActionProcessor, DoneProcessor, @@ -43,7 +39,6 @@ __all__ = [ "DoneProcessor", "EnvTransition", "IdentityProcessor", - "ImageProcessor", "InfoProcessor", "NormalizerProcessor", "UnnormalizerProcessor", @@ -53,7 +48,6 @@ __all__ = [ "RenameProcessor", "RewardProcessor", "RobotProcessor", - "StateProcessor", "TransitionKey", "TruncatedProcessor", 
"VanillaObservationProcessor", diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 8d7d04878..0f00bb470 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -20,6 +20,7 @@ import torch from lerobot.configs.types import PolicyFeature from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.utils.utils import get_safe_torch_device @dataclass @@ -30,10 +31,11 @@ class DeviceProcessor: specified device (CPU or GPU) before they are returned. """ - device: str = "cpu" + device: torch.device = "cpu" def __post_init__(self): - self.non_blocking = "cuda" in self.device + self.device = get_safe_torch_device(self.device) + self.non_blocking = "cuda" in str(self.device) def __call__(self, transition: EnvTransition) -> EnvTransition: # Create a copy of the transition diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index a8424013c..14628727f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -220,7 +220,6 @@ class UnnormalizerProcessor: features: dict[str, PolicyFeature] norm_map: dict[FeatureType, NormalizationMode] stats: dict[str, dict[str, Any]] | None = None - eps: float = 1e-8 _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) @@ -230,10 +229,8 @@ class UnnormalizerProcessor: dataset: LeRobotDataset, features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], - *, - eps: float = 1e-8, ) -> UnnormalizerProcessor: - return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps) + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) def __post_init__(self): # Handle deserialization from JSON config @@ -308,7 +305,6 @@ class UnnormalizerProcessor: def get_config(self) -> dict[str, Any]: return { - "eps": self.eps, 
"features": { key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() }, diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 091b1286d..7d63db238 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -13,8 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field -from typing import Any +from dataclasses import dataclass import einops import numpy as np @@ -23,52 +22,27 @@ from torch import Tensor from lerobot.configs.types import PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry @dataclass -class ImageProcessor: - """Process image observations from environment format to policy format. - - Converts images from: - - Channel-last (H, W, C) to channel-first (C, H, W) - - uint8 [0, 255] to float32 [0, 1] - - Adds batch dimension if needed - - Handles both single images and dictionaries of images +@ProcessorStepRegistry.register(name="observation_processor") +class VanillaObservationProcessor(ObservationProcessor): """ + Processes environment observations into the LeRobot format by handling both images and states. 
- def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) + Image processing: + - Converts channel-last (H, W, C) images to channel-first (C, H, W) + - Normalizes uint8 images ([0, 255]) to float32 ([0, 1]) + - Adds a batch dimension if missing + - Supports single images and image dictionaries - if observation is None: - return transition - - processed_obs = {} - - # Copy all observations first - for key, value in observation.items(): - processed_obs[key] = value - - # Handle pixels key if present - pixels = observation.get("pixels") - if pixels is not None: - # Remove pixels from processed_obs since we'll replace it with processed images - processed_obs.pop("pixels", None) - # Determine image mapping - if isinstance(pixels, dict): - imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()} - else: - imgs = {OBS_IMAGE: pixels} - - # Process each image - for imgkey, img in imgs.items(): - processed_img = self._process_single_image(img) - processed_obs[imgkey] = processed_img - - # Return new transition with processed observation - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = processed_obs - return new_transition + State processing: + - Maps 'environment_state' to observation.environment_state + - Maps 'agent_pos' to observation.state + - Converts numpy arrays to tensors + - Adds a batch dimension if missing + """ def _process_single_image(self, img: np.ndarray) -> Tensor: """Process a single image array.""" @@ -95,173 +69,89 @@ class ImageProcessor: return img_tensor - def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - 
"""Reset processor state (no-op for this processor).""" - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Transforms: - pixels -> OBS_IMAGE, - observation.pixels -> OBS_IMAGE, - pixels. -> OBS_IMAGES., - observation.pixels. -> OBS_IMAGES. + def _process_observation(self, observation): + """ + Processes both image and state observations. """ - if "pixels" in features: - features[OBS_IMAGE] = features.pop("pixels") - if "observation.pixels" in features: - features[OBS_IMAGE] = features.pop("observation.pixels") - prefixes = ("pixels.", "observation.pixels.") - for key in list(features.keys()): - for p in prefixes: - if key.startswith(p): - suffix = key[len(p) :] - features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key) - break - return features + processed_obs = observation.copy() + if "pixels" in processed_obs: + pixels = processed_obs.pop("pixels") -@dataclass -class StateProcessor: - """Process state observations from environment format to policy format. 
+ if isinstance(pixels, dict): + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()} + else: + imgs = {OBS_IMAGE: pixels} - Handles: - - environment_state -> observation.environment_state - - agent_pos -> observation.state - - Converts numpy arrays to tensors - - Adds batch dimension if needed - """ + for imgkey, img in imgs.items(): + processed_obs[imgkey] = self._process_single_image(img) - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - - if observation is None: - return transition - - processed_obs = dict(observation) # Copy existing observations - - # Process environment_state - if "environment_state" in observation: - env_state = torch.from_numpy(observation["environment_state"]).float() + if "environment_state" in processed_obs: + env_state_np = processed_obs.pop("environment_state") + env_state = torch.from_numpy(env_state_np).float() if env_state.dim() == 1: env_state = env_state.unsqueeze(0) processed_obs[OBS_ENV_STATE] = env_state - # Remove original key - del processed_obs["environment_state"] - # Process agent_pos - if "agent_pos" in observation: - agent_pos = torch.from_numpy(observation["agent_pos"]).float() + if "agent_pos" in processed_obs: + agent_pos_np = processed_obs.pop("agent_pos") + agent_pos = torch.from_numpy(agent_pos_np).float() if agent_pos.dim() == 1: agent_pos = agent_pos.unsqueeze(0) processed_obs[OBS_STATE] = agent_pos - # Remove original key - del processed_obs["agent_pos"] - # Return new transition with processed observation - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = processed_obs - return new_transition + return processed_obs - def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary (empty for this processor).""" - return {} - - def load_state_dict(self, state: dict[str, 
torch.Tensor]) -> None: - """Load state dictionary (no-op for this processor).""" - pass - - def reset(self) -> None: - """Reset processor state (no-op for this processor).""" - pass + def observation(self, observation): + return self._process_observation(observation) def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Transforms: - environment_state -> OBS_ENV_STATE, - agent_pos -> OBS_STATE, - observation.environment_state -> OBS_ENV_STATE, - observation.agent_pos -> OBS_STATE + """Transforms feature keys to a standardized contract. + + This method handles several renaming patterns: + - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). + - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). + - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1'). + - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1'). + - environment_state -> OBS_ENV_STATE, + - agent_pos -> OBS_STATE, + - observation.environment_state -> OBS_ENV_STATE, + - observation.agent_pos -> OBS_STATE """ - pairs = ( - ("environment_state", OBS_ENV_STATE), - ("agent_pos", OBS_STATE), - ) - for old, new in pairs: - if old in features: - features[new] = features.pop(old) - prefixed = f"observation.{old}" - if prefixed in features: - features[new] = features.pop(prefixed) - return features - - -@dataclass -@ProcessorStepRegistry.register(name="observation_processor") -class VanillaObservationProcessor: - """Complete observation processor that combines image and state processing. - - This processor replicates the functionality of the original preprocess_observation - function but in a modular, composable way that fits into the pipeline architecture. 
- """ - - image_processor: ImageProcessor = field(default_factory=ImageProcessor) - state_processor: StateProcessor = field(default_factory=StateProcessor) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - # First process images - transition = self.image_processor(transition) - # Then process state - transition = self.state_processor(transition) - return transition - - def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return { - "image_processor": self.image_processor.get_config(), - "state_processor": self.state_processor.get_config(), - } - - def state_dict(self) -> dict[str, torch.Tensor]: - """Return state dictionary.""" - state = {} - state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()}) - state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()}) - return state - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - """Load state dictionary.""" - image_state = { - k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.") - } - state_state = { - k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.") - } - - self.image_processor.load_state_dict(image_state) - self.state_processor.load_state_dict(state_state) - - def reset(self) -> None: - """Reset processor state.""" - self.image_processor.reset() - self.state_processor.reset() - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features = self.image_processor.feature_contract(features) - features = self.state_processor.feature_contract(features) + exact_pairs = { + "pixels": OBS_IMAGE, + "environment_state": OBS_ENV_STATE, + "agent_pos": OBS_STATE, + } + + prefix_pairs = { + "pixels.": f"{OBS_IMAGES}.", + } + + for key in list(features.keys()): + matched_prefix = False + for old_prefix, new_prefix in prefix_pairs.items(): + 
prefixed_old = f"observation.{old_prefix}" + if key.startswith(prefixed_old): + suffix = key[len(prefixed_old) :] + features[f"{new_prefix}{suffix}"] = features.pop(key) + matched_prefix = True + break + + if key.startswith(old_prefix): + suffix = key[len(old_prefix) :] + features[f"{new_prefix}{suffix}"] = features.pop(key) + matched_prefix = True + break + + if matched_prefix: + continue + + for old, new in exact_pairs.items(): + if key == old or key == f"observation.{old}": + if key in features: + features[new] = features.pop(key) + break + return features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 7e78830fb..6e1b2a2cb 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -36,6 +36,7 @@ from lerobot.configs.types import PolicyFeature class TransitionKey(str, Enum): """Keys for accessing EnvTransition dictionary components.""" + # TODO(Steven): Use consts OBSERVATION = "observation" ACTION = "action" REWARD = "reward" @@ -45,19 +46,18 @@ class TransitionKey(str, Enum): COMPLEMENTARY_DATA = "complementary_data" -class EnvTransition(TypedDict, total=False): - """Environment transition data structure. - - All fields are optional (total=False) to allow flexible usage. 
- """ - - observation: dict[str, Any] | None - action: Any | torch.Tensor | None - reward: float | torch.Tensor | None - done: bool | torch.Tensor | None - truncated: bool | torch.Tensor | None - info: dict[str, Any] | None - complementary_data: dict[str, Any] | None +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: Any | torch.Tensor | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) class ProcessorStepRegistry: @@ -135,8 +135,8 @@ class ProcessorStepRegistry: class ProcessorStep(Protocol): """Structural typing interface for a single processor step. - A step is any callable accepting a full `EnvTransition` tuple and - returning a (possibly modified) tuple of the same structure. Implementers + A step is any callable accepting a full `EnvTransition` dict and + returning a (possibly modified) dict of the same structure. Implementers are encouraged—but not required—to expose the optional helper methods listed below. When present, these hooks let `RobotProcessor` automatically serialise the step's configuration and learnable state using @@ -254,24 +254,22 @@ class RobotProcessor(ModelHubMixin): Composable, debuggable post-processing processor for robot transitions. The class orchestrates an ordered collection of small, functional transforms—steps—executed - left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples + left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts and batch dictionaries, automatically converting between formats as needed. Args: steps: Ordered list of processing steps executed on every call. Defaults to empty list. 
name: Human-readable identifier that is persisted inside the JSON config. Defaults to "RobotProcessor". - seed: Global seed forwarded to steps that choose to consume it. Defaults to None. - to_transition: Function to convert batch dict to EnvTransition tuple. + to_transition: Function to convert batch dict to EnvTransition dict. Defaults to _default_batch_to_transition. - to_output: Function to convert EnvTransition tuple to the desired output format. - Usually it is a batch dict or EnvTransition tuple. + to_output: Function to convert EnvTransition dict to the desired output format. + Usually it is a batch dict or EnvTransition dict. Defaults to _default_transition_to_batch. before_step_hooks: List of hooks called before each step. Each hook receives the step index and transition, and can optionally return a modified transition. after_step_hooks: List of hooks called after each step. Each hook receives the step index and transition, and can optionally return a modified transition. - reset_hooks: List of hooks called during processor reset. Hook Semantics: - Hooks are executed sequentially in the order they were registered. 
There is no way to @@ -290,7 +288,6 @@ class RobotProcessor(ModelHubMixin): steps: Sequence[ProcessorStep] = field(default_factory=list) name: str = "RobotProcessor" - seed: int | None = None to_transition: Callable[[dict[str, Any]], EnvTransition] = field( default_factory=lambda: _default_batch_to_transition, repr=False @@ -303,7 +300,6 @@ class RobotProcessor(ModelHubMixin): # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) - reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) def __call__(self, data: EnvTransition | dict[str, Any]): """Process data through all steps. @@ -431,7 +427,6 @@ class RobotProcessor(ModelHubMixin): config: dict[str, Any] = { "name": self.name, - "seed": self.seed, "steps": [], } @@ -728,7 +723,7 @@ class RobotProcessor(ModelHubMixin): f"Make sure override keys match exact step class names or registry names." ) - return cls(steps, loaded_config.get("name", "RobotProcessor"), loaded_config.get("seed")) + return cls(steps, loaded_config.get("name", "RobotProcessor")) def __len__(self) -> int: """Return the number of steps in the processor.""" @@ -740,7 +735,7 @@ class RobotProcessor(ModelHubMixin): * ``slice`` – returns a new RobotProcessor with the sliced steps. """ if isinstance(idx, slice): - return RobotProcessor(self.steps[idx], self.name, self.seed) + return RobotProcessor(self.steps[idx], self.name) return self.steps[idx] def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): @@ -783,71 +778,11 @@ class RobotProcessor(ModelHubMixin): f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference." 
) from None - def register_reset_hook(self, fn: Callable[[], None]): - """Attach fn to be executed when reset is called.""" - self.reset_hooks.append(fn) - - def unregister_reset_hook(self, fn: Callable[[], None]): - """Remove a previously registered reset hook. - - Args: - fn: The exact function reference that was registered. Must be the same object. - - Raises: - ValueError: If the hook is not found in the registered hooks. - """ - try: - self.reset_hooks.remove(fn) - except ValueError: - raise ValueError( - f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference." - ) from None - def reset(self): """Clear state in every step that implements ``reset()`` and fire registered hooks.""" for step in self.steps: if hasattr(step, "reset"): step.reset() # type: ignore[attr-defined] - for fn in self.reset_hooks: - fn() - - def profile_steps( - self, transition: EnvTransition, num_runs: int = 100, warmup_runs: int = 5 - ) -> dict[str, float]: - """Profile the execution time of each step for performance optimization.""" - import copy - import time - - profile_results = {} - - # Make a copy to avoid altering the original transition - transition_copy = copy.deepcopy(transition) - - # Get intermediate transitions for each step using step_through - intermediate_transitions = list(self.step_through(transition_copy)) - - for idx, processor_step in enumerate(self.steps): - step_name = f"step_{idx}_{processor_step.__class__.__name__}" - - # Use the appropriate input transition for this step - input_transition = intermediate_transitions[idx] - - # Warm up - copy transition for each run to ensure consistent conditions - for _ in range(warmup_runs): - transition_copy = copy.deepcopy(input_transition) - _ = processor_step(transition_copy) - - # Time the step - copy transition for each run to ensure consistent conditions - start_time = time.perf_counter() - for _ in range(num_runs): - transition_copy = copy.deepcopy(input_transition) - _ = 
processor_step(transition_copy) - end_time = time.perf_counter() - - avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds - profile_results[step_name] = avg_time - - return profile_results def __repr__(self) -> str: """Return a readable string representation of the processor.""" @@ -864,9 +799,6 @@ class RobotProcessor(ModelHubMixin): parts = [f"name='{self.name}'", steps_repr] - if self.seed is not None: - parts.append(f"seed={self.seed}") - return f"RobotProcessor({', '.join(parts)})" def __post_init__(self): @@ -931,6 +863,9 @@ class ObservationProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + processed_observation = self.observation(observation) # Create a new transition dict with the processed observation new_transition = transition.copy() @@ -988,6 +923,9 @@ class ActionProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + processed_action = self.action(action) # Create a new transition dict with the processed action new_transition = transition.copy() @@ -1044,6 +982,9 @@ class RewardProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: reward = transition.get(TransitionKey.REWARD) + if reward is None: + return transition + processed_reward = self.reward(reward) # Create a new transition dict with the processed reward new_transition = transition.copy() @@ -1105,6 +1046,9 @@ class DoneProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: done = transition.get(TransitionKey.DONE) + if done is None: + return transition + processed_done = self.done(done) # Create a new transition dict with the processed done flag new_transition = transition.copy() @@ -1162,6 +1106,9 @@ class TruncatedProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: 
truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is None: + return transition + processed_truncated = self.truncated(truncated) # Create a new transition dict with the processed truncated flag new_transition = transition.copy() @@ -1224,6 +1171,9 @@ class InfoProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: info = transition.get(TransitionKey.INFO) + if info is None: + return transition + processed_info = self.info(info) # Create a new transition dict with the processed info new_transition = transition.copy() @@ -1267,6 +1217,9 @@ class ComplementaryDataProcessor: def __call__(self, transition: EnvTransition) -> EnvTransition: complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return transition + processed_complementary_data = self.complementary_data(complementary_data) # Create a new transition dict with the processed complementary data new_transition = transition.copy() diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 7e1897541..4fe4105a5 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -16,24 +16,21 @@ from dataclasses import dataclass, field from typing import Any -import torch - from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor.pipeline import ( + ObservationProcessor, + ProcessorStepRegistry, +) @dataclass @ProcessorStepRegistry.register(name="rename_processor") -class RenameProcessor: +class RenameProcessor(ObservationProcessor): """Rename processor that renames keys in the observation.""" rename_map: dict[str, str] = field(default_factory=dict) - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return transition - + def observation(self, 
observation): processed_obs = {} for key, value in observation.items(): if key in self.rename_map: @@ -41,20 +38,11 @@ class RenameProcessor: else: processed_obs[key] = value - # Create a new transition with the renamed observation - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = processed_obs - return new_transition + return processed_obs def get_config(self) -> dict[str, Any]: return {"rename_map": self.rename_map} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: """Transforms: - Each key in the observation that appears in `rename_map` is renamed to its value. diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index fb6a78155..e48b6bc08 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -20,11 +20,7 @@ import torch from lerobot.configs.types import FeatureType from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor import ( - ImageProcessor, - StateProcessor, - VanillaObservationProcessor, -) +from lerobot.processor import VanillaObservationProcessor from lerobot.processor.pipeline import TransitionKey from tests.conftest import assert_contract_is_typed @@ -46,7 +42,7 @@ def create_transition( def test_process_single_image(): """Test processing a single image.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() # Create a mock image (H, W, C) format, uint8 image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) @@ -72,7 +68,7 @@ def test_process_single_image(): def test_process_image_dict(): """Test processing multiple images in a dictionary.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() # Create mock images image1 = 
np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -95,7 +91,7 @@ def test_process_image_dict(): def test_process_batched_image(): """Test processing already batched images.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() # Create a batched image (B, H, W, C) image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) @@ -112,7 +108,7 @@ def test_process_batched_image(): def test_invalid_image_format(): """Test error handling for invalid image formats.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() # Test wrong channel order (channels first) image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) @@ -125,7 +121,7 @@ def test_invalid_image_format(): def test_invalid_image_dtype(): """Test error handling for invalid image dtype.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() # Test wrong dtype image = np.random.rand(64, 64, 3).astype(np.float32) @@ -138,7 +134,7 @@ def test_invalid_image_dtype(): def test_no_pixels_in_observation(): """Test processor when no pixels are in observation.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() observation = {"other_data": np.array([1, 2, 3])} transition = create_transition(observation=observation) @@ -153,7 +149,7 @@ def test_no_pixels_in_observation(): def test_none_observation(): """Test processor with None observation.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() transition = create_transition() result = processor(transition) @@ -163,7 +159,7 @@ def test_none_observation(): def test_serialization_methods(): """Test serialization methods.""" - processor = ImageProcessor() + processor = VanillaObservationProcessor() # Test get_config config = processor.get_config() @@ -182,7 +178,7 @@ def test_serialization_methods(): def test_process_environment_state(): """Test processing environment_state.""" - processor = StateProcessor() + processor = 
VanillaObservationProcessor() env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) observation = {"environment_state": env_state} @@ -203,7 +199,7 @@ def test_process_environment_state(): def test_process_agent_pos(): """Test processing agent_pos.""" - processor = StateProcessor() + processor = VanillaObservationProcessor() agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) observation = {"agent_pos": agent_pos} @@ -224,7 +220,7 @@ def test_process_agent_pos(): def test_process_batched_states(): """Test processing already batched states.""" - processor = StateProcessor() + processor = VanillaObservationProcessor() env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) @@ -242,7 +238,7 @@ def test_process_batched_states(): def test_process_both_states(): """Test processing both environment_state and agent_pos.""" - processor = StateProcessor() + processor = VanillaObservationProcessor() env_state = np.array([1.0, 2.0], dtype=np.float32) agent_pos = np.array([0.5, -0.5], dtype=np.float32) @@ -267,7 +263,7 @@ def test_process_both_states(): def test_no_states_in_observation(): """Test processor when no states are in observation.""" - processor = StateProcessor() + processor = VanillaObservationProcessor() observation = {"other_data": np.array([1, 2, 3])} transition = create_transition(observation=observation) @@ -359,17 +355,6 @@ def test_empty_observation(): assert processed_obs == {} -def test_custom_sub_processors(): - """Test ObservationProcessor with custom sub-processors.""" - image_proc = ImageProcessor() - state_proc = StateProcessor() - processor = VanillaObservationProcessor(image_processor=image_proc, state_processor=state_proc) - - # Should use the provided processors - assert processor.image_processor is image_proc - assert processor.state_processor is state_proc - - def test_equivalent_to_original_function(): """Test that ObservationProcessor produces equivalent results 
to preprocess_observation.""" # Import the original function for comparison @@ -426,7 +411,7 @@ def test_equivalent_with_image_dict(): def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): - processor = ImageProcessor() + processor = VanillaObservationProcessor() features = { "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), @@ -440,7 +425,7 @@ def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): - processor = ImageProcessor() + processor = VanillaObservationProcessor() features = { "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "keep": policy_feature_factory(FeatureType.ENV, (1,)), @@ -454,7 +439,7 @@ def test_image_processor_feature_contract_observation_pixels_to_image(policy_fea def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): - processor = ImageProcessor() + processor = VanillaObservationProcessor() features = { "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), @@ -472,7 +457,7 @@ def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_featu def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): - processor = StateProcessor() + processor = VanillaObservationProcessor() features = { "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), @@ -488,7 +473,7 @@ def test_state_processor_feature_contract_environment_and_agent_pos(policy_featu def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): - proc = StateProcessor() + proc = VanillaObservationProcessor() features = { "observation.environment_state": 
policy_feature_factory(FeatureType.STATE, (2,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 0822b474d..5665d5a7d 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -363,32 +363,6 @@ def test_hooks(): assert after_calls == [0] -def test_reset(): - """Test pipeline reset functionality.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - - reset_called = [] - - def reset_hook(): - reset_called.append(True) - - pipeline.register_reset_hook(reset_hook) - - # Make some calls to increment counter - transition = create_transition() - pipeline(transition) - pipeline(transition) - - assert step.counter == 2 - - # Reset should reset step and call hook - pipeline.reset() - - assert step.counter == 0 - assert len(reset_called) == 1 - - def test_unregister_hooks(): """Test unregistering hooks from the pipeline.""" step = MockStep("test_step") @@ -428,21 +402,6 @@ def test_unregister_hooks(): pipeline(transition) assert len(after_calls) == 0 - # Test reset_hook - reset_calls = [] - - def reset_hook(): - reset_calls.append(True) - - pipeline.register_reset_hook(reset_hook) - pipeline.reset() - assert len(reset_calls) == 1 - - pipeline.unregister_reset_hook(reset_hook) - reset_calls.clear() - pipeline.reset() - assert len(reset_calls) == 0 - def test_unregister_nonexistent_hook(): """Test error handling when unregistering hooks that don't exist.""" @@ -461,9 +420,6 @@ def test_unregister_nonexistent_hook(): with pytest.raises(ValueError, match="not found in after_step_hooks"): pipeline.unregister_after_step_hook(some_hook) - with pytest.raises(ValueError, match="not found in reset_hooks"): - pipeline.unregister_reset_hook(reset_hook) - def test_multiple_hooks_and_selective_unregister(): """Test registering multiple hooks and selectively unregistering them.""" @@ -552,22 +508,6 @@ def 
test_hook_execution_order_documentation(): assert execution_order == ["A", "C", "B"] # B is now last -def test_profile_steps(): - """Test step profiling functionality.""" - step1 = MockStep("step1") - step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) - - transition = create_transition() - - profile_results = pipeline.profile_steps(transition, num_runs=10) - - assert len(profile_results) == 2 - assert "step_0_MockStep" in profile_results - assert "step_1_MockStep" in profile_results - assert all(isinstance(time, float) and time >= 0 for time in profile_results.values()) - - def test_save_and_load_pretrained(): """Test saving and loading pipeline. @@ -581,7 +521,7 @@ def test_save_and_load_pretrained(): step1.counter = 5 step2.counter = 10 - pipeline = RobotProcessor([step1, step2], name="TestPipeline", seed=42) + pipeline = RobotProcessor([step1, step2], name="TestPipeline") with tempfile.TemporaryDirectory() as tmp_dir: # Save pipeline @@ -596,7 +536,6 @@ def test_save_and_load_pretrained(): config = json.load(f) assert config["name"] == "TestPipeline" - assert config["seed"] == 42 assert len(config["steps"]) == 2 # Verify counters are saved in config, not in separate state files @@ -607,7 +546,6 @@ def test_save_and_load_pretrained(): loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) assert loaded_pipeline.name == "TestPipeline" - assert loaded_pipeline.seed == 42 assert len(loaded_pipeline) == 2 # Check that counter was restored from config @@ -1255,10 +1193,10 @@ def test_repr_with_custom_name(): def test_repr_with_seed(): """Test __repr__ with seed parameter.""" step = MockStep("test_step") - pipeline = RobotProcessor([step], seed=42) + pipeline = RobotProcessor([step]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep], seed=42)" + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" assert repr_str == expected @@ -1266,19 +1204,17 @@ def 
test_repr_with_custom_name_and_seed(): """Test __repr__ with both custom name and seed.""" step1 = MockStep("step1") step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2], name="MyProcessor", seed=123) + pipeline = RobotProcessor([step1, step2], name="MyProcessor") repr_str = repr(pipeline) - expected = ( - "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods], seed=123)" - ) + expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" assert repr_str == expected def test_repr_without_seed(): """Test __repr__ when seed is explicitly None (should not show seed).""" step = MockStep("test_step") - pipeline = RobotProcessor([step], name="TestProcessor", seed=None) + pipeline = RobotProcessor([step], name="TestProcessor") repr_str = repr(pipeline) expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" @@ -1306,10 +1242,10 @@ def test_repr_edge_case_long_names(): step3 = MockStepWithTensorState() step4 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames", seed=999) + pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") repr_str = repr(pipeline) - expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState], seed=999)" + expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" assert repr_str == expected