From 9183083e75e12fdcb276d4ed858c4eba23c257f7 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 10 Sep 2025 22:40:37 +0200 Subject: [PATCH] refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908) * refactor(processor): split action from policy, robots and environment - Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions. - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention. - Enhanced type annotations for action parameters to improve code readability and maintainability. * refactor(converters): rename robot_transition_to_action to transition_to_robot_action - Updated function names across multiple files to improve clarity and consistency in processing robot actions. - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention. - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence. * refactor(converters): update references to transition_to_robot_action - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions. - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability. * refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing. - This modification enhances the clarity of the class's role in the processing pipeline. * fix(processor): main action processor can take also EnvAction --------- Co-authored-by: Steven Palma --- examples/phone_to_so100/evaluate.py | 4 +- examples/phone_to_so100/record.py | 8 +- examples/phone_to_so100/replay.py | 6 +- examples/phone_to_so100/teleoperate.py | 6 +- src/lerobot/processor/__init__.py | 4 + src/lerobot/processor/batch_processor.py | 10 +- src/lerobot/processor/converters.py | 16 +- src/lerobot/processor/core.py | 10 +- .../processor/delta_action_processor.py | 14 +- src/lerobot/processor/device_processor.py | 6 +- src/lerobot/processor/gym_action_processor.py | 14 +- src/lerobot/processor/hil_processor.py | 6 +- src/lerobot/processor/normalize_processor.py | 21 ++- src/lerobot/processor/pipeline.py | 82 ++++++++- src/lerobot/record.py | 10 +- src/lerobot/replay.py | 6 +- .../robot_kinematic_processor.py | 18 +- src/lerobot/teleoperate.py | 10 +- tests/processor/test_batch_conversion.py | 16 +- tests/processor/test_batch_processor.py | 163 +++++++++++------- tests/processor/test_normalize_processor.py | 4 +- tests/processor/test_sac_processor.py | 8 +- 22 files changed, 303 insertions(+), 139 deletions(-) diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index fc25a0acd..de8dd9073 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -25,7 +25,7 @@ from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( identity_transition, observation_to_transition, - transition_to_action, + transition_to_robot_action, ) from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig @@ -76,7 +76,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline( ), ], to_transition=identity_transition, - to_output=transition_to_action, + to_output=transition_to_robot_action, ) # Build pipeline to convert joint observation to ee pose observation diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index f25835f96..c47667f4f 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -22,10 +22,10 @@ from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( - action_to_transition, identity_transition, observation_to_transition, - transition_to_action, + robot_action_to_transition, + transition_to_robot_action, ) from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig @@ -89,7 +89,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline( max_ee_twist_step_rad=0.50, ), ], - to_transition=action_to_transition, + to_transition=robot_action_to_transition, to_output=identity_transition, ) @@ -107,7 +107,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline( ), ], to_transition=identity_transition, - to_output=transition_to_action, + to_output=transition_to_robot_action, ) # Build pipeline to convert joint observation to ee pose observation diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index ffeaa7c2b..180fdfb3f 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -20,7 +20,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics from lerobot.processor import RobotProcessorPipeline -from lerobot.processor.converters import action_to_transition, transition_to_action +from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( AddRobotObservationAsComplimentaryData, @@ -59,8 +59,8 @@ robot_ee_to_joints_processor = RobotProcessorPipeline( initial_guess_current_joints=False, # Because replay is open loop ), ], - to_transition=action_to_transition, - to_output=transition_to_action, + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, ) robot_ee_to_joints_processor.reset() diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index 5be126c32..f2125544a 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -17,7 +17,7 @@ import time from lerobot.model.kinematics import RobotKinematics from lerobot.processor import RobotProcessorPipeline -from lerobot.processor.converters import action_to_transition, transition_to_action +from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( AddRobotObservationAsComplimentaryData, @@ -72,8 +72,8 @@ phone_to_robot_joints_processor = RobotProcessorPipeline( speed_factor=20.0, ), ], - to_transition=action_to_transition, - to_output=transition_to_action, + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, ) robot.connect() diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 66d074eb6..746d922e9 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -46,11 +46,13 @@ from .pipeline import ( IdentityProcessorStep, InfoProcessorStep, ObservationProcessorStep, + PolicyActionProcessorStep, PolicyProcessorPipeline, ProcessorKwargs, ProcessorStep, ProcessorStepRegistry, RewardProcessorStep, + RobotActionProcessorStep, RobotProcessorPipeline, TruncatedProcessorStep, ) @@ -81,10 +83,12 @@ __all__ = [ "NormalizerProcessorStep", "Numpy2TorchActionProcessorStep", "ObservationProcessorStep", + "PolicyActionProcessorStep", "PolicyProcessorPipeline", "ProcessorKwargs", "ProcessorStep", "ProcessorStepRegistry", + "RobotActionProcessorStep", "RenameObservationsProcessorStep", "RewardClassifierProcessorStep", "RewardProcessorStep", diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 64bb1f6f3..1ba016b4e 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -27,11 +27,11 @@ from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from .core import EnvTransition +from .core import EnvTransition, PolicyAction from .pipeline import ( - ActionProcessorStep, ComplementaryDataProcessorStep, ObservationProcessorStep, + PolicyActionProcessorStep, ProcessorStep, ProcessorStepRegistry, ) @@ -39,14 +39,14 @@ from .pipeline import ( @dataclass @ProcessorStepRegistry.register(name="to_batch_processor_action") -class AddBatchDimensionActionStep(ActionProcessorStep): +class AddBatchDimensionActionStep(PolicyActionProcessorStep): """ Processor step to add a batch dimension to a 1D tensor action. This is useful for creating a batch of size 1 from a single action sample. """ - def action(self, action: Tensor) -> Tensor: + def action(self, action: PolicyAction) -> PolicyAction: """ Adds a batch dimension to the action if it's a 1D tensor. @@ -56,7 +56,7 @@ class AddBatchDimensionActionStep(ActionProcessorStep): Returns: The action tensor with an added batch dimension. """ - if not isinstance(action, Tensor) or action.dim() != 1: + if action.dim() != 1: return action return action.unsqueeze(0) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index cdc1f8621..8456cad11 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -26,7 +26,7 @@ import torch from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED -from .core import EnvTransition, TransitionKey +from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey @singledispatch @@ -243,7 +243,7 @@ def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransiti def create_transition( observation: dict[str, Any] | None = None, - action: dict[str, Any] | None = None, + action: PolicyAction | RobotAction | None = None, reward: float = 0.0, done: bool = False, truncated: bool = False, @@ -276,9 +276,9 @@ def create_transition( } -def action_to_transition(action: dict[str, Any]) -> EnvTransition: +def robot_action_to_transition(action: RobotAction) -> EnvTransition: """ - Convert a raw action dictionary into a standardized `EnvTransition`. + Convert a raw robot action dictionary into a standardized `EnvTransition`. The keys in the action dictionary are prefixed with "action." and stored under the `ACTION` key in the transition. Values are converted to tensors, except for @@ -315,9 +315,9 @@ def observation_to_transition(observation: dict[str, Any]) -> EnvTransition: return create_transition(observation={**state, **image_observations}, action={}) -def transition_to_action(transition: EnvTransition) -> dict[str, Any]: +def transition_to_robot_action(transition: EnvTransition) -> RobotAction: """ - Extract a raw action dictionary for a robot from an `EnvTransition`. + Extract a raw robot action dictionary for a robot from an `EnvTransition`. This function searches for keys in the format "action.*.pos" or "action.*.vel" and converts them into a flat dictionary suitable for sending to a robot controller. @@ -460,6 +460,10 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: if not isinstance(batch, dict): raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") + action = batch.get("action") + if action is not None and not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + # Extract observation and complementary data keys. observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} complementary_data = _extract_complementary_data(batch) diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py index a60a52d02..9a16cbaff 100644 --- a/src/lerobot/processor/core.py +++ b/src/lerobot/processor/core.py @@ -17,8 +17,9 @@ from __future__ import annotations from enum import Enum -from typing import Any, TypedDict +from typing import Any, TypeAlias, TypedDict +import numpy as np import torch @@ -35,11 +36,16 @@ class TransitionKey(str, Enum): COMPLEMENTARY_DATA = "complementary_data" +PolicyAction: TypeAlias = torch.Tensor +RobotAction: TypeAlias = dict[str, Any] +EnvAction: TypeAlias = np.ndarray + + EnvTransition = TypedDict( "EnvTransition", { TransitionKey.OBSERVATION.value: dict[str, Any] | None, - TransitionKey.ACTION.value: Any | torch.Tensor | None, + TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None, TransitionKey.REWARD.value: float | torch.Tensor | None, TransitionKey.DONE.value: bool | torch.Tensor | None, TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index 0135705bd..f081fd35d 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -16,11 +16,10 @@ from dataclasses import dataclass -from torch import Tensor - from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from .pipeline import ActionProcessorStep, ProcessorStepRegistry +from .core import PolicyAction, RobotAction +from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep @ProcessorStepRegistry.register("map_tensor_to_delta_action_dict") @@ -40,7 +39,10 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep): use_gripper: bool = True - def action(self, action: Tensor) -> dict: + def action(self, action: PolicyAction) -> RobotAction: + if not isinstance(action, PolicyAction): + raise ValueError("Only PolicyAction is supported for this processor") + if action.dim() > 1: action = action.squeeze(0) @@ -69,7 +71,7 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep): @ProcessorStepRegistry.register("map_delta_action_to_robot_action") @dataclass -class MapDeltaActionToRobotActionStep(ActionProcessorStep): +class MapDeltaActionToRobotActionStep(RobotActionProcessorStep): """ Maps delta actions from teleoperators to robot target actions for inverse kinematics. @@ -89,7 +91,7 @@ class MapDeltaActionToRobotActionStep(ActionProcessorStep): rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise - def action(self, action: dict) -> dict: + def action(self, action: RobotAction) -> RobotAction: # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices delta_x = action.pop("delta_x", 0.0) diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index ffe8a6af7..467b194bc 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -27,7 +27,7 @@ import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.utils.utils import get_safe_torch_device -from .core import EnvTransition, TransitionKey +from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ProcessorStep, ProcessorStepRegistry @@ -129,6 +129,10 @@ class DeviceProcessorStep(ProcessorStep): A new `EnvTransition` object with all tensors moved to the target device and dtype. """ new_transition = transition.copy() + action = new_transition.get(TransitionKey.ACTION) + + if action is not None and not isinstance(action, PolicyAction): + raise ValueError(f"If action is not None should be a PolicyAction type got {type(action)}") simple_tensor_keys = [ TransitionKey.ACTION, diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index af728c3e9..b8bafe51f 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -16,12 +16,10 @@ from dataclasses import dataclass -import numpy as np -import torch - from lerobot.configs.types import PipelineFeatureType, PolicyFeature from .converters import to_tensor +from .core import EnvAction, PolicyAction from .pipeline import ActionProcessorStep, ProcessorStepRegistry @@ -42,10 +40,10 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep): squeeze_batch_dim: bool = True - def action(self, action: torch.Tensor) -> np.ndarray: - if not isinstance(action, torch.Tensor): + def action(self, action: PolicyAction) -> EnvAction: + if not isinstance(action, PolicyAction): raise TypeError( - f"Expected torch.Tensor or None, got {type(action).__name__}. " + f"Expected PolicyAction or None, got {type(action).__name__}. " "Use appropriate processor for non-tensor actions." ) @@ -80,8 +78,8 @@ class Numpy2TorchActionProcessorStep(ActionProcessorStep): by a policy or model. """ - def action(self, action: np.ndarray) -> torch.Tensor: - if not isinstance(action, np.ndarray): + def action(self, action: EnvAction) -> PolicyAction: + if not isinstance(action, EnvAction): raise TypeError( f"Expected np.ndarray or None, got {type(action).__name__}. " "Use appropriate processor for non-tensor actions." diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index fcf1aeca9..1f69c5be0 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -28,7 +28,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents -from .core import EnvTransition, TransitionKey +from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ( ComplementaryDataProcessorStep, InfoProcessorStep, @@ -416,8 +416,8 @@ class InterventionActionProcessorStep(ProcessorStep): reward, and termination status. """ action = transition.get(TransitionKey.ACTION) - if action is None: - return transition + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") # Get intervention signals from complementary data info = transition.get(TransitionKey.INFO, {}) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 7e9c6b527..9d2a93e3c 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -28,7 +28,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatur from lerobot.datasets.lerobot_dataset import LeRobotDataset from .converters import from_tensor_to_numpy, to_tensor -from .core import EnvTransition, TransitionKey +from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry @@ -345,8 +345,14 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep): # Handle action normalization. action = new_transition.get(TransitionKey.ACTION) - if action is not None: - new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) + + if action is None: + return new_transition + + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) return new_transition @@ -401,8 +407,13 @@ class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep): # Handle action unnormalization. action = new_transition.get(TransitionKey.ACTION) - if action is not None: - new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) + + if action is None: + return new_transition + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) return new_transition diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 9afff7e48..c3440ff36 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -32,7 +32,7 @@ from safetensors.torch import load_file, save_file from lerobot.configs.types import PipelineFeatureType, PolicyFeature from .converters import batch_to_transition, create_transition, transition_to_batch -from .core import EnvTransition, TransitionKey +from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey # Type variable for generic processor output type TOutput = TypeVar("TOutput") @@ -859,7 +859,9 @@ class ActionProcessorStep(ProcessorStep, ABC): """ @abstractmethod - def action(self, action) -> Any | torch.Tensor: + def action( + self, action: PolicyAction | RobotAction | EnvAction + ) -> PolicyAction | RobotAction | EnvAction: """Process the action component. Args: @@ -878,6 +880,82 @@ class ActionProcessorStep(ProcessorStep, ABC): if action is None: raise ValueError("ActionProcessorStep requires an action in the transition.") + processed_action = self.action(action) + new_transition[TransitionKey.ACTION] = processed_action + raise ValueError("ActionProcessorStep requires an action in the transition.") + + +class RobotActionProcessorStep(ProcessorStep, ABC): + """Base class for processors that modify only the robot action component of a transition. + + Subclasses should override the `action` method to implement custom robot action processing. + This class handles the boilerplate of extracting and reinserting the processed action + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific robot action processing logic. + """ + + @abstractmethod + def action(self, action: RobotAction) -> RobotAction: + """Process the robot action component. + + Args: + action: The robot action to process + + Returns: + The processed robot action + """ + ... + + def __call__(self, transition: EnvTransition) -> EnvTransition: + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + # NOTE: We can't use isinstance(action, RobotAction) because RobotAction is a dict[str, Any] + # because Any is generic + if not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type got {type(action)}") + + processed_action = self.action(action=action) + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + +class PolicyActionProcessorStep(ProcessorStep, ABC): + """Base class for processors that modify only the policy action component of a transition. + + Subclasses should override the `action` method to implement custom policy action processing. + This class handles the boilerplate of extracting and reinserting the processed action + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific policy action processing logic. + """ + + @abstractmethod + def action(self, action: PolicyAction) -> PolicyAction: + """Process the policy action component. + + Args: + action: The policy action to process + + Returns: + The processed policy action + """ + ... + + def __call__(self, transition: EnvTransition) -> EnvTransition: + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + processed_action = self.action(action) new_transition[TransitionKey.ACTION] = processed_action return new_transition diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 9888c8411..e0f40f6b4 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -85,11 +85,11 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import ( - action_to_transition, identity_transition, observation_to_transition, - transition_to_action, + robot_action_to_transition, transition_to_dataset_frame, + transition_to_robot_action, ) from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 @@ -255,7 +255,9 @@ def record_loop( teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( teleop_action_processor or RobotProcessorPipeline( - steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition + steps=[IdentityProcessorStep()], + to_transition=robot_action_to_transition, + to_output=identity_transition, ) ) robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = ( @@ -263,7 +265,7 @@ def record_loop( or RobotProcessorPipeline( steps=[IdentityProcessorStep()], to_transition=identity_transition, - to_output=transition_to_action, + to_output=transition_to_robot_action, ) ) robot_observation_processor: RobotProcessorPipeline[EnvTransition] = ( diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index ae2d01e04..c641a22d1 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -48,7 +48,7 @@ from pprint import pformat from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline -from lerobot.processor.converters import action_to_transition, transition_to_action +from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -97,8 +97,8 @@ def replay(cfg: ReplayConfig): # Initialize robot action processor with default if not provided robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline( steps=[IdentityProcessorStep()], - to_transition=action_to_transition, - to_output=transition_to_action, # type: ignore[arg-type] + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, # type: ignore[arg-type] ) # Reset processor diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 9db737cfa..7c0c1eed7 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -22,21 +22,22 @@ from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeatur from lerobot.constants import OBS_STATE from lerobot.model.kinematics import RobotKinematics from lerobot.processor import ( - ActionProcessorStep, ComplementaryDataProcessorStep, EnvTransition, ObservationProcessorStep, ProcessorStep, ProcessorStepRegistry, + RobotActionProcessorStep, TransitionKey, ) +from lerobot.processor.core import RobotAction from lerobot.robots.robot import Robot from lerobot.utils.rotation import Rotation @ProcessorStepRegistry.register("ee_reference_and_delta") @dataclass -class EEReferenceAndDelta(ActionProcessorStep): +class EEReferenceAndDelta(RobotActionProcessorStep): """ Computes a target end-effector pose from a relative delta command. @@ -72,7 +73,7 @@ class EEReferenceAndDelta(ActionProcessorStep): _prev_enabled: bool = field(default=False, init=False, repr=False) _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) - def action(self, action): + def action(self, action: RobotAction) -> RobotAction: new_action = action.copy() comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) @@ -171,7 +172,7 @@ class EEReferenceAndDelta(ActionProcessorStep): @ProcessorStepRegistry.register("ee_bounds_and_safety") @dataclass -class EEBoundsAndSafety(ActionProcessorStep): +class EEBoundsAndSafety(RobotActionProcessorStep): """ Clips the end-effector pose to predefined bounds and checks for unsafe jumps. @@ -192,7 +193,7 @@ class EEBoundsAndSafety(ActionProcessorStep): _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) _last_twist: np.ndarray | None = field(default=None, init=False, repr=False) - def action(self, act: dict) -> dict: + def action(self, act: RobotAction) -> RobotAction: x = act.get("ee.x", None) y = act.get("ee.y", None) z = act.get("ee.z", None) @@ -266,6 +267,10 @@ class InverseKinematicsEEToJoints(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: new_transition = transition.copy() act = new_transition.get(TransitionKey.ACTION) or {} + + if not isinstance(act, dict): + raise ValueError(f"Action should be a RobotAction type got {type(act)}") + comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} x = act.get("ee.x", None) @@ -361,6 +366,9 @@ class GripperVelocityToJoint(ProcessorStep): act = new_transition.get(TransitionKey.ACTION) or {} comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + if not isinstance(act, dict): + raise ValueError(f"Action should be a RobotAction type got {type(act)}") + if "gripper" not in act: raise ValueError("Required action key 'gripper' not found in transition") diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 44ef73278..a24c9fb0c 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -64,10 +64,10 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.configs import parser from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey from lerobot.processor.converters import ( - action_to_transition, identity_transition, observation_to_transition, - transition_to_action, + robot_action_to_transition, + transition_to_robot_action, ) from lerobot.robots import ( # noqa: F401 Robot, @@ -140,7 +140,9 @@ def teleop_loop( teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( teleop_action_processor or RobotProcessorPipeline( - steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition + steps=[IdentityProcessorStep()], + to_transition=robot_action_to_transition, + to_output=identity_transition, ) ) robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = ( @@ -148,7 +150,7 @@ def teleop_loop( or RobotProcessorPipeline( steps=[IdentityProcessorStep()], to_transition=identity_transition, - to_output=transition_to_action, # type: ignore[arg-type] + to_output=transition_to_robot_action, # type: ignore[arg-type] ) ) robot_observation_processor: RobotProcessorPipeline[EnvTransition] = ( diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 8d1f5e20e..631ad7899 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -49,7 +49,7 @@ def test_batch_to_transition_observation_grouping(): "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), "observation.state": [1, 2, 3, 4], - "action": "action_data", + "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, "next.truncated": False, @@ -74,7 +74,7 @@ def test_batch_to_transition_observation_grouping(): assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) assert transition[TransitionKey.REWARD] == 1.5 assert transition[TransitionKey.DONE] assert not transition[TransitionKey.TRUNCATED] @@ -123,7 +123,7 @@ def test_transition_to_batch_observation_flattening(): def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { - "action": "action_data", + "action": torch.tensor([1.0, 2.0]), "next.reward": 2.0, "next.done": False, "next.truncated": True, @@ -136,7 +136,7 @@ def test_no_observation_keys(): assert transition[TransitionKey.OBSERVATION] is None # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0])) assert transition[TransitionKey.REWARD] == 2.0 assert not transition[TransitionKey.DONE] assert transition[TransitionKey.TRUNCATED] @@ -144,7 +144,7 @@ def test_no_observation_keys(): # Round trip should work reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["action"] == "action_data" + assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0])) assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] assert reconstructed_batch["next.truncated"] @@ -153,13 +153,13 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": "minimal_action"} + batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} transition = batch_to_transition(batch) # Check observation assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} - assert transition[TransitionKey.ACTION] == "minimal_action" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) # Check defaults assert transition[TransitionKey.REWARD] == 0.0 @@ -171,7 +171,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["observation.state"] == "minimal_state" - assert reconstructed_batch["action"] == "minimal_action" + assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index 100c9648c..219a81578 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -38,7 +38,7 @@ def test_state_1d_to_2d(): # Test observation.state state_1d = torch.randn(7) observation = {OBS_STATE: state_1d} - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -54,7 +54,7 @@ def test_env_state_1d_to_2d(): # Test observation.environment_state env_state_1d = torch.randn(10) observation = {OBS_ENV_STATE: env_state_1d} - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -70,7 +70,7 @@ def test_image_3d_to_4d(): # Test observation.image image_3d = torch.randn(224, 224, 3) observation = {OBS_IMAGE: image_3d} - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -90,7 +90,7 @@ def test_multiple_images_3d_to_4d(): f"{OBS_IMAGES}.camera1": image1_3d, f"{OBS_IMAGES}.camera2": image2_3d, } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -118,7 +118,7 @@ def test_already_batched_tensors_unchanged(): OBS_ENV_STATE: env_state_2d, OBS_IMAGE: image_4d, } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -142,7 +142,7 @@ def test_higher_dimensional_tensors_unchanged(): OBS_STATE: state_3d, OBS_IMAGE: image_5d, } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -163,7 +163,7 @@ def test_non_tensor_values_unchanged(): "custom_key": 42, # Integer "another_key": {"nested": "dict"}, # Dict } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -180,7 +180,7 @@ def test_none_observation(): """Test processor handles None observation gracefully.""" processor = AddBatchDimensionProcessorStep() - transition = create_transition(observation={}, action={}) + transition = create_transition(observation={}, action=torch.empty(0)) result = processor(transition) assert result[TransitionKey.OBSERVATION] == {} @@ -191,7 +191,7 @@ def test_empty_observation(): processor = AddBatchDimensionProcessorStep() observation = {} - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) @@ -216,7 +216,7 @@ def test_mixed_observation(): "other_tensor": other_tensor, "non_tensor": "string_value", } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] @@ -243,7 +243,7 @@ def test_integration_with_robot_processor(): OBS_STATE: torch.randn(7), OBS_IMAGE: torch.randn(224, 224, 3), } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = pipeline(transition) processed_obs = result[TransitionKey.OBSERVATION] @@ -299,7 +299,7 @@ def test_save_and_load_pretrained(): # Test functionality of loaded processor observation = {OBS_STATE: torch.randn(5)} - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = loaded_pipeline(transition) assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) @@ -333,7 +333,7 @@ def test_registry_based_save_load(): OBS_STATE: torch.randn(3), OBS_IMAGE: torch.randn(100, 100, 3), } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = loaded_pipeline(transition) processed_obs = result[TransitionKey.OBSERVATION] @@ -355,7 +355,7 @@ def test_device_compatibility(): OBS_STATE: state_1d, OBS_IMAGE: image_3d, } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] @@ -415,7 +415,7 @@ def test_edge_case_zero_dimensional_tensors(): OBS_STATE: scalar_tensor, "scalar_value": scalar_tensor, } - transition = create_transition(observation=observation, action={}) + transition = create_transition(observation=observation, action=torch.empty(0)) result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] @@ -490,42 +490,43 @@ def test_action_scalar_tensor(): assert torch.equal(result[TransitionKey.ACTION], action_scalar) -def test_action_non_tensor(): - """Test that non-tensor actions remain unchanged.""" +def test_action_non_tensor_raises_error(): + """Test that non-tensor actions raise ValueError for PolicyAction processors.""" processor = AddBatchDimensionProcessorStep() - # List action + # List action should raise error action_list = [0.1, 0.2, 0.3, 0.4] - transition = create_transition(action=action_list, observation={}) - result = processor(transition) - assert result[TransitionKey.ACTION] == action_list + transition = create_transition(action=action_list) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) - # Numpy array action (as Python object, not converted) + # Numpy array action should raise error action_numpy = np.array([1, 2, 3, 4]) - transition = create_transition(action=action_numpy, observation={}) - result = processor(transition) - assert np.array_equal(result[TransitionKey.ACTION], action_numpy) + transition = create_transition(action=action_numpy) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) - # String action (edge case) + # String action should raise error action_string = "forward" - transition = create_transition(action=action_string, observation={}) - result = processor(transition) - assert result[TransitionKey.ACTION] == action_string + transition = create_transition(action=action_string) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) - # Dict action (structured action) + # Dict action should raise error action_dict = {"linear": [0.5, 0.0], "angular": 0.2} - transition = create_transition(action=action_dict, observation={}) - result = processor(transition) - assert result[TransitionKey.ACTION] == action_dict + transition = create_transition(action=action_dict) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) def test_action_none(): - """Test that None action is handled correctly.""" + """Test that empty action tensor is handled correctly.""" processor = AddBatchDimensionProcessorStep() - transition = create_transition(action={}, observation={}) + transition = create_transition(action=torch.empty(0), observation={}) result = processor(transition) - assert result[TransitionKey.ACTION] == {} + # Empty 1D tensor becomes empty 2D tensor with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 0) def test_action_with_observation(): @@ -630,7 +631,9 @@ def test_task_string_to_list(): # Create complementary data with string task complementary_data = {"task": "pick_cube"} - transition = create_transition(action={}, observation={}, complementary_data=complementary_data) + transition = create_transition( + action=torch.empty(0), observation={}, complementary_data=complementary_data + ) result = processor(transition) @@ -647,14 +650,18 @@ def test_task_string_validation(): # Valid string task - should be converted to list complementary_data = {"task": "valid_task"} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert processed_comp_data["task"] == ["valid_task"] # Valid list of strings - should remain unchanged complementary_data = {"task": ["task1", "task2"]} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] assert processed_comp_data["task"] == ["task1", "task2"] @@ -676,7 +683,9 @@ def test_task_list_of_strings(): for task_list in test_lists: complementary_data = {"task": task_list} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -690,7 +699,7 @@ def test_complementary_data_none(): """Test processor handles None complementary_data gracefully.""" processor = AddBatchDimensionProcessorStep() - transition = create_transition(complementary_data=None, action={}, observation={}) + transition = create_transition(complementary_data=None, action=torch.empty(0), observation={}) result = processor(transition) assert result[TransitionKey.COMPLEMENTARY_DATA] == {} @@ -701,7 +710,9 @@ def test_complementary_data_empty(): processor = AddBatchDimensionProcessorStep() complementary_data = {} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -717,7 +728,9 @@ def test_complementary_data_no_task(): "timestamp": 1234567890.0, "extra_info": "some data", } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -736,7 +749,9 @@ def test_complementary_data_mixed(): "difficulty": "hard", "metadata": {"scene": "kitchen"}, } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -803,7 +818,9 @@ def test_task_comprehensive_string_cases(): # Test that all string tasks get properly batched for task in string_tasks: complementary_data = {"task": task} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -825,7 +842,9 @@ def test_task_comprehensive_string_cases(): for task_list in list_tasks: complementary_data = {"task": task_list} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -845,7 +864,9 @@ def test_task_preserves_other_keys(): "config": {"speed": "slow", "precision": "high"}, "metrics": [1.0, 2.0, 3.0], } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -869,7 +890,9 @@ def test_index_scalar_to_1d(): # Create 0D index tensor (scalar) index_0d = torch.tensor(42, dtype=torch.int64) complementary_data = {"index": index_0d} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -886,7 +909,9 @@ def test_task_index_scalar_to_1d(): # Create 0D task_index tensor (scalar) task_index_0d = torch.tensor(7, dtype=torch.int64) complementary_data = {"task_index": task_index_0d} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -908,7 +933,9 @@ def test_index_and_task_index_together(): "task_index": task_index_0d, "task": "pick_object", } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -936,13 +963,17 @@ def test_index_already_batched(): # Test 1D (already batched) complementary_data = {"index": index_1d} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d) # Test 2D complementary_data = {"index": index_2d} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d) @@ -957,13 +988,17 @@ def test_task_index_already_batched(): # Test 1D (already batched) complementary_data = {"task_index": task_index_1d} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d) # Test 2D complementary_data = {"task_index": task_index_2d} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d) @@ -976,7 +1011,9 @@ def test_index_non_tensor_unchanged(): "index": 42, # Plain int, not tensor "task_index": [1, 2, 3], # List, not tensor } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -999,7 +1036,9 @@ def test_index_dtype_preservation(): "index": index_0d, "task_index": task_index_0d, } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -1062,7 +1101,9 @@ def test_index_device_compatibility(): "index": index_0d, "task_index": task_index_0d, } - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] @@ -1081,7 +1122,9 @@ def test_empty_index_tensor(): # Empty 0D tensor doesn't make sense, but test empty 1D index_empty = torch.tensor([], dtype=torch.int64) complementary_data = {"index": index_empty} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) result = processor(transition) @@ -1116,7 +1159,9 @@ def test_task_processing_creates_new_transition(): processor = AddBatchDimensionProcessorStep() complementary_data = {"task": "sort_objects"} - transition = create_transition(complementary_data=complementary_data, observation={}, action={}) + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) # Store reference to original transition and complementary_data original_transition = transition diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index dcb450cd1..0a28320ae 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -329,14 +329,14 @@ def test_min_max_unnormalization(action_stats_min_max): assert torch.allclose(unnormalized_action, expected) -def test_numpy_action_input(action_stats_mean_std): +def test_tensor_action_input(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) - normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) + normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32) transition = create_transition(action=normalized_action) unnormalized_transition = unnormalizer(transition) diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py index 5e26b5c8f..8d2bd8453 100644 --- a/tests/processor/test_sac_processor.py +++ b/tests/processor/test_sac_processor.py @@ -371,12 +371,12 @@ def test_sac_processor_edge_cases(): assert processed[TransitionKey.OBSERVATION] == {} assert processed[TransitionKey.ACTION].shape == (1, 5) - # Test with None action - transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action={}) + # Test with zero action (representing "null" action) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5)) processed = preprocessor(transition) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) - # When action is None, it may still be present with None value - assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None + # Action should be present and batched, even if it's zeros + assert processed[TransitionKey.ACTION].shape == (1, 5) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")