refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908)

* refactor(processor): split action from policy, robots and environment

- Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions.
- Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention.
- Enhanced type annotations for action parameters to improve code readability and maintainability.

* refactor(converters): rename robot_transition_to_action to transition_to_robot_action

- Updated function names across multiple files to improve clarity and consistency in processing robot actions.
- Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention.
- Simplified action handling in AddBatchDimensionActionStep by removing the redundant type check on the action.

* refactor(converters): update references to transition_to_robot_action

- Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions.
- Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability.

* refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep

- Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing.
- This modification enhances the clarity of the class's role in the processing pipeline.

* fix(processor): main action processor can also take EnvAction

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-09-10 22:40:37 +02:00
committed by GitHub
parent 6745958362
commit 9183083e75
22 changed files with 303 additions and 139 deletions
+2 -2
View File
@@ -25,7 +25,7 @@ from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import ( from lerobot.processor.converters import (
identity_transition, identity_transition,
observation_to_transition, observation_to_transition,
transition_to_action, transition_to_robot_action,
) )
from lerobot.record import record_loop from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
@@ -76,7 +76,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline(
), ),
], ],
to_transition=identity_transition, to_transition=identity_transition,
to_output=transition_to_action, to_output=transition_to_robot_action,
) )
# Build pipeline to convert joint observation to ee pose observation # Build pipeline to convert joint observation to ee pose observation
+4 -4
View File
@@ -22,10 +22,10 @@ from lerobot.datasets.utils import combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import ( from lerobot.processor.converters import (
action_to_transition,
identity_transition, identity_transition,
observation_to_transition, observation_to_transition,
transition_to_action, robot_action_to_transition,
transition_to_robot_action,
) )
from lerobot.record import record_loop from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
@@ -89,7 +89,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
max_ee_twist_step_rad=0.50, max_ee_twist_step_rad=0.50,
), ),
], ],
to_transition=action_to_transition, to_transition=robot_action_to_transition,
to_output=identity_transition, to_output=identity_transition,
) )
@@ -107,7 +107,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline(
), ),
], ],
to_transition=identity_transition, to_transition=identity_transition,
to_output=transition_to_action, to_output=transition_to_robot_action,
) )
# Build pipeline to convert joint observation to ee pose observation # Build pipeline to convert joint observation to ee pose observation
+3 -3
View File
@@ -20,7 +20,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_action from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData, AddRobotObservationAsComplimentaryData,
@@ -59,8 +59,8 @@ robot_ee_to_joints_processor = RobotProcessorPipeline(
initial_guess_current_joints=False, # Because replay is open loop initial_guess_current_joints=False, # Because replay is open loop
), ),
], ],
to_transition=action_to_transition, to_transition=robot_action_to_transition,
to_output=transition_to_action, to_output=transition_to_robot_action,
) )
robot_ee_to_joints_processor.reset() robot_ee_to_joints_processor.reset()
+3 -3
View File
@@ -17,7 +17,7 @@ import time
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import RobotProcessorPipeline from lerobot.processor import RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_action from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import ( from lerobot.robots.so100_follower.robot_kinematic_processor import (
AddRobotObservationAsComplimentaryData, AddRobotObservationAsComplimentaryData,
@@ -72,8 +72,8 @@ phone_to_robot_joints_processor = RobotProcessorPipeline(
speed_factor=20.0, speed_factor=20.0,
), ),
], ],
to_transition=action_to_transition, to_transition=robot_action_to_transition,
to_output=transition_to_action, to_output=transition_to_robot_action,
) )
robot.connect() robot.connect()
+4
View File
@@ -46,11 +46,13 @@ from .pipeline import (
IdentityProcessorStep, IdentityProcessorStep,
InfoProcessorStep, InfoProcessorStep,
ObservationProcessorStep, ObservationProcessorStep,
PolicyActionProcessorStep,
PolicyProcessorPipeline, PolicyProcessorPipeline,
ProcessorKwargs, ProcessorKwargs,
ProcessorStep, ProcessorStep,
ProcessorStepRegistry, ProcessorStepRegistry,
RewardProcessorStep, RewardProcessorStep,
RobotActionProcessorStep,
RobotProcessorPipeline, RobotProcessorPipeline,
TruncatedProcessorStep, TruncatedProcessorStep,
) )
@@ -81,10 +83,12 @@ __all__ = [
"NormalizerProcessorStep", "NormalizerProcessorStep",
"Numpy2TorchActionProcessorStep", "Numpy2TorchActionProcessorStep",
"ObservationProcessorStep", "ObservationProcessorStep",
"PolicyActionProcessorStep",
"PolicyProcessorPipeline", "PolicyProcessorPipeline",
"ProcessorKwargs", "ProcessorKwargs",
"ProcessorStep", "ProcessorStep",
"ProcessorStepRegistry", "ProcessorStepRegistry",
"RobotActionProcessorStep",
"RenameObservationsProcessorStep", "RenameObservationsProcessorStep",
"RewardClassifierProcessorStep", "RewardClassifierProcessorStep",
"RewardProcessorStep", "RewardProcessorStep",
+5 -5
View File
@@ -27,11 +27,11 @@ from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .core import EnvTransition from .core import EnvTransition, PolicyAction
from .pipeline import ( from .pipeline import (
ActionProcessorStep,
ComplementaryDataProcessorStep, ComplementaryDataProcessorStep,
ObservationProcessorStep, ObservationProcessorStep,
PolicyActionProcessorStep,
ProcessorStep, ProcessorStep,
ProcessorStepRegistry, ProcessorStepRegistry,
) )
@@ -39,14 +39,14 @@ from .pipeline import (
@dataclass @dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_action") @ProcessorStepRegistry.register(name="to_batch_processor_action")
class AddBatchDimensionActionStep(ActionProcessorStep): class AddBatchDimensionActionStep(PolicyActionProcessorStep):
""" """
Processor step to add a batch dimension to a 1D tensor action. Processor step to add a batch dimension to a 1D tensor action.
This is useful for creating a batch of size 1 from a single action sample. This is useful for creating a batch of size 1 from a single action sample.
""" """
def action(self, action: Tensor) -> Tensor: def action(self, action: PolicyAction) -> PolicyAction:
""" """
Adds a batch dimension to the action if it's a 1D tensor. Adds a batch dimension to the action if it's a 1D tensor.
@@ -56,7 +56,7 @@ class AddBatchDimensionActionStep(ActionProcessorStep):
Returns: Returns:
The action tensor with an added batch dimension. The action tensor with an added batch dimension.
""" """
if not isinstance(action, Tensor) or action.dim() != 1: if action.dim() != 1:
return action return action
return action.unsqueeze(0) return action.unsqueeze(0)
+10 -6
View File
@@ -26,7 +26,7 @@ import torch
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from .core import EnvTransition, TransitionKey from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
@singledispatch @singledispatch
@@ -243,7 +243,7 @@ def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransiti
def create_transition( def create_transition(
observation: dict[str, Any] | None = None, observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None, action: PolicyAction | RobotAction | None = None,
reward: float = 0.0, reward: float = 0.0,
done: bool = False, done: bool = False,
truncated: bool = False, truncated: bool = False,
@@ -276,9 +276,9 @@ def create_transition(
} }
def action_to_transition(action: dict[str, Any]) -> EnvTransition: def robot_action_to_transition(action: RobotAction) -> EnvTransition:
""" """
Convert a raw action dictionary into a standardized `EnvTransition`. Convert a raw robot action dictionary into a standardized `EnvTransition`.
The keys in the action dictionary are prefixed with "action." and stored under The keys in the action dictionary are prefixed with "action." and stored under
the `ACTION` key in the transition. Values are converted to tensors, except for the `ACTION` key in the transition. Values are converted to tensors, except for
@@ -315,9 +315,9 @@ def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
return create_transition(observation={**state, **image_observations}, action={}) return create_transition(observation={**state, **image_observations}, action={})
def transition_to_action(transition: EnvTransition) -> dict[str, Any]: def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
""" """
Extract a raw action dictionary for a robot from an `EnvTransition`. Extract a raw robot action dictionary for a robot from an `EnvTransition`.
This function searches for keys in the format "action.*.pos" or "action.*.vel" This function searches for keys in the format "action.*.pos" or "action.*.vel"
and converts them into a flat dictionary suitable for sending to a robot controller. and converts them into a flat dictionary suitable for sending to a robot controller.
@@ -460,6 +460,10 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
if not isinstance(batch, dict): if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
action = batch.get("action")
if action is not None and not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
# Extract observation and complementary data keys. # Extract observation and complementary data keys.
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
complementary_data = _extract_complementary_data(batch) complementary_data = _extract_complementary_data(batch)
+8 -2
View File
@@ -17,8 +17,9 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import Any, TypedDict from typing import Any, TypeAlias, TypedDict
import numpy as np
import torch import torch
@@ -35,11 +36,16 @@ class TransitionKey(str, Enum):
COMPLEMENTARY_DATA = "complementary_data" COMPLEMENTARY_DATA = "complementary_data"
# Action represented as a single torch tensor; this is the type that
# PolicyActionProcessorStep subclasses receive and return.
PolicyAction: TypeAlias = torch.Tensor
# Action represented as a string-keyed dictionary; this is the type that
# RobotActionProcessorStep subclasses receive and return.
# NOTE: this is a parameterized generic alias, so it cannot be passed to
# isinstance(); check against plain `dict` instead.
RobotAction: TypeAlias = dict[str, Any]
# Action represented as a NumPy array.
# NOTE(review): presumably the form an environment consumes — confirm against callers.
EnvAction: TypeAlias = np.ndarray
EnvTransition = TypedDict( EnvTransition = TypedDict(
"EnvTransition", "EnvTransition",
{ {
TransitionKey.OBSERVATION.value: dict[str, Any] | None, TransitionKey.OBSERVATION.value: dict[str, Any] | None,
TransitionKey.ACTION.value: Any | torch.Tensor | None, TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None,
TransitionKey.REWARD.value: float | torch.Tensor | None, TransitionKey.REWARD.value: float | torch.Tensor | None,
TransitionKey.DONE.value: bool | torch.Tensor | None, TransitionKey.DONE.value: bool | torch.Tensor | None,
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
@@ -16,11 +16,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from torch import Tensor
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from .pipeline import ActionProcessorStep, ProcessorStepRegistry from .core import PolicyAction, RobotAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict") @ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@@ -40,7 +39,10 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
use_gripper: bool = True use_gripper: bool = True
def action(self, action: Tensor) -> dict: def action(self, action: PolicyAction) -> RobotAction:
if not isinstance(action, PolicyAction):
raise ValueError("Only PolicyAction is supported for this processor")
if action.dim() > 1: if action.dim() > 1:
action = action.squeeze(0) action = action.squeeze(0)
@@ -69,7 +71,7 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
@ProcessorStepRegistry.register("map_delta_action_to_robot_action") @ProcessorStepRegistry.register("map_delta_action_to_robot_action")
@dataclass @dataclass
class MapDeltaActionToRobotActionStep(ActionProcessorStep): class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
""" """
Maps delta actions from teleoperators to robot target actions for inverse kinematics. Maps delta actions from teleoperators to robot target actions for inverse kinematics.
@@ -89,7 +91,7 @@ class MapDeltaActionToRobotActionStep(ActionProcessorStep):
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
def action(self, action: dict) -> dict: def action(self, action: RobotAction) -> RobotAction:
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
delta_x = action.pop("delta_x", 0.0) delta_x = action.pop("delta_x", 0.0)
+5 -1
View File
@@ -27,7 +27,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.utils import get_safe_torch_device from lerobot.utils.utils import get_safe_torch_device
from .core import EnvTransition, TransitionKey from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry from .pipeline import ProcessorStep, ProcessorStepRegistry
@@ -129,6 +129,10 @@ class DeviceProcessorStep(ProcessorStep):
A new `EnvTransition` object with all tensors moved to the target device and dtype. A new `EnvTransition` object with all tensors moved to the target device and dtype.
""" """
new_transition = transition.copy() new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is not None and not isinstance(action, PolicyAction):
raise ValueError(f"If action is not None should be a PolicyAction type got {type(action)}")
simple_tensor_keys = [ simple_tensor_keys = [
TransitionKey.ACTION, TransitionKey.ACTION,
@@ -16,12 +16,10 @@
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import to_tensor from .converters import to_tensor
from .core import EnvAction, PolicyAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@@ -42,10 +40,10 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep):
squeeze_batch_dim: bool = True squeeze_batch_dim: bool = True
def action(self, action: torch.Tensor) -> np.ndarray: def action(self, action: PolicyAction) -> EnvAction:
if not isinstance(action, torch.Tensor): if not isinstance(action, PolicyAction):
raise TypeError( raise TypeError(
f"Expected torch.Tensor or None, got {type(action).__name__}. " f"Expected PolicyAction or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions." "Use appropriate processor for non-tensor actions."
) )
@@ -80,8 +78,8 @@ class Numpy2TorchActionProcessorStep(ActionProcessorStep):
by a policy or model. by a policy or model.
""" """
def action(self, action: np.ndarray) -> torch.Tensor: def action(self, action: EnvAction) -> PolicyAction:
if not isinstance(action, np.ndarray): if not isinstance(action, EnvAction):
raise TypeError( raise TypeError(
f"Expected np.ndarray or None, got {type(action).__name__}. " f"Expected np.ndarray or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions." "Use appropriate processor for non-tensor actions."
+3 -3
View File
@@ -28,7 +28,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from .core import EnvTransition, TransitionKey from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import ( from .pipeline import (
ComplementaryDataProcessorStep, ComplementaryDataProcessorStep,
InfoProcessorStep, InfoProcessorStep,
@@ -416,8 +416,8 @@ class InterventionActionProcessorStep(ProcessorStep):
reward, and termination status. reward, and termination status.
""" """
action = transition.get(TransitionKey.ACTION) action = transition.get(TransitionKey.ACTION)
if action is None: if not isinstance(action, PolicyAction):
return transition raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
# Get intervention signals from complementary data # Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {}) info = transition.get(TransitionKey.INFO, {})
+14 -3
View File
@@ -28,7 +28,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatur
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from .converters import from_tensor_to_numpy, to_tensor from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, TransitionKey from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
@@ -345,7 +345,13 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
# Handle action normalization. # Handle action normalization.
action = new_transition.get(TransitionKey.ACTION) action = new_transition.get(TransitionKey.ACTION)
if action is not None:
if action is None:
return new_transition
if not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
return new_transition return new_transition
@@ -401,7 +407,12 @@ class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
# Handle action unnormalization. # Handle action unnormalization.
action = new_transition.get(TransitionKey.ACTION) action = new_transition.get(TransitionKey.ACTION)
if action is not None:
if action is None:
return new_transition
if not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
return new_transition return new_transition
+80 -2
View File
@@ -32,7 +32,7 @@ from safetensors.torch import load_file, save_file
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from .converters import batch_to_transition, create_transition, transition_to_batch from .converters import batch_to_transition, create_transition, transition_to_batch
from .core import EnvTransition, TransitionKey from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
# Type variable for generic processor output type # Type variable for generic processor output type
TOutput = TypeVar("TOutput") TOutput = TypeVar("TOutput")
@@ -859,7 +859,9 @@ class ActionProcessorStep(ProcessorStep, ABC):
""" """
@abstractmethod @abstractmethod
def action(self, action) -> Any | torch.Tensor: def action(
self, action: PolicyAction | RobotAction | EnvAction
) -> PolicyAction | RobotAction | EnvAction:
"""Process the action component. """Process the action component.
Args: Args:
@@ -878,6 +880,82 @@ class ActionProcessorStep(ProcessorStep, ABC):
if action is None: if action is None:
raise ValueError("ActionProcessorStep requires an action in the transition.") raise ValueError("ActionProcessorStep requires an action in the transition.")
processed_action = self.action(action)
new_transition[TransitionKey.ACTION] = processed_action
return new_transition
class RobotActionProcessorStep(ProcessorStep, ABC):
    """Base class for processors that modify only the robot action component of a transition.

    Subclasses should override the `action` method to implement custom robot action processing.
    This class handles the boilerplate of extracting and reinserting the processed action
    into the transition dict, eliminating the need to implement the `__call__` method in subclasses.

    By inheriting from this class, you avoid writing repetitive code to handle transition dict
    manipulation, focusing only on the specific robot action processing logic.
    """

    @abstractmethod
    def action(self, action: RobotAction) -> RobotAction:
        """Process the robot action component.

        Args:
            action: The robot action to process

        Returns:
            The processed robot action
        """
        ...

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Apply `action` to the transition's action entry and return the updated copy.

        Args:
            transition: The incoming transition. It is shallow-copied, so the
                caller's mapping is not modified by the reassignment below.

        Returns:
            The copied transition with its action entry replaced by the
            processed robot action.

        Raises:
            ValueError: If the transition's action is missing or is not a dict.
        """
        # The copy is also stored on the instance.
        # NOTE(review): presumably so `action` implementations can read other
        # transition fields while processing — confirm against ProcessorStep.
        self._current_transition = transition.copy()
        new_transition = self._current_transition

        action = new_transition.get(TransitionKey.ACTION)
        # NOTE: isinstance() cannot be used with RobotAction itself because it is
        # a parameterized generic alias (dict[str, Any]); check against plain
        # `dict` instead. A missing action (None) is rejected by the same check.
        if not isinstance(action, dict):
            raise ValueError(f"Action should be a RobotAction type got {type(action)}")

        # Pass the action positionally: overriding methods are not required to
        # name their parameter `action` (some use e.g. `act`), and a keyword
        # call would raise TypeError for them.
        processed_action = self.action(action)
        new_transition[TransitionKey.ACTION] = processed_action
        return new_transition
class PolicyActionProcessorStep(ProcessorStep, ABC):
"""Base class for processors that modify only the policy action component of a transition.
Subclasses should override the `action` method to implement custom policy action processing.
This class handles the boilerplate of extracting and reinserting the processed action
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
By inheriting from this class, you avoid writing repetitive code to handle transition dict
manipulation, focusing only on the specific policy action processing logic.
"""
@abstractmethod
def action(self, action: PolicyAction) -> PolicyAction:
"""Process the policy action component.
Args:
action: The policy action to process
Returns:
The processed policy action
"""
...
def __call__(self, transition: EnvTransition) -> EnvTransition:
self._current_transition = transition.copy()
new_transition = self._current_transition
action = new_transition.get(TransitionKey.ACTION)
if not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
processed_action = self.action(action) processed_action = self.action(action)
new_transition[TransitionKey.ACTION] = processed_action new_transition[TransitionKey.ACTION] = processed_action
return new_transition return new_transition
+6 -4
View File
@@ -85,11 +85,11 @@ from lerobot.processor import (
TransitionKey, TransitionKey,
) )
from lerobot.processor.converters import ( from lerobot.processor.converters import (
action_to_transition,
identity_transition, identity_transition,
observation_to_transition, observation_to_transition,
transition_to_action, robot_action_to_transition,
transition_to_dataset_frame, transition_to_dataset_frame,
transition_to_robot_action,
) )
from lerobot.processor.rename_processor import rename_stats from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
@@ -255,7 +255,9 @@ def record_loop(
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
teleop_action_processor teleop_action_processor
or RobotProcessorPipeline( or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition steps=[IdentityProcessorStep()],
to_transition=robot_action_to_transition,
to_output=identity_transition,
) )
) )
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = ( robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
@@ -263,7 +265,7 @@ def record_loop(
or RobotProcessorPipeline( or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], steps=[IdentityProcessorStep()],
to_transition=identity_transition, to_transition=identity_transition,
to_output=transition_to_action, to_output=transition_to_robot_action,
) )
) )
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = ( robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
+3 -3
View File
@@ -48,7 +48,7 @@ from pprint import pformat
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
from lerobot.processor.converters import action_to_transition, transition_to_action from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
@@ -97,8 +97,8 @@ def replay(cfg: ReplayConfig):
# Initialize robot action processor with default if not provided # Initialize robot action processor with default if not provided
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline( robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], steps=[IdentityProcessorStep()],
to_transition=action_to_transition, to_transition=robot_action_to_transition,
to_output=transition_to_action, # type: ignore[arg-type] to_output=transition_to_robot_action, # type: ignore[arg-type]
) )
# Reset processor # Reset processor
@@ -22,21 +22,22 @@ from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeatur
from lerobot.constants import OBS_STATE from lerobot.constants import OBS_STATE
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import ( from lerobot.processor import (
ActionProcessorStep,
ComplementaryDataProcessorStep, ComplementaryDataProcessorStep,
EnvTransition, EnvTransition,
ObservationProcessorStep, ObservationProcessorStep,
ProcessorStep, ProcessorStep,
ProcessorStepRegistry, ProcessorStepRegistry,
RobotActionProcessorStep,
TransitionKey, TransitionKey,
) )
from lerobot.processor.core import RobotAction
from lerobot.robots.robot import Robot from lerobot.robots.robot import Robot
from lerobot.utils.rotation import Rotation from lerobot.utils.rotation import Rotation
@ProcessorStepRegistry.register("ee_reference_and_delta") @ProcessorStepRegistry.register("ee_reference_and_delta")
@dataclass @dataclass
class EEReferenceAndDelta(ActionProcessorStep): class EEReferenceAndDelta(RobotActionProcessorStep):
""" """
Computes a target end-effector pose from a relative delta command. Computes a target end-effector pose from a relative delta command.
@@ -72,7 +73,7 @@ class EEReferenceAndDelta(ActionProcessorStep):
_prev_enabled: bool = field(default=False, init=False, repr=False) _prev_enabled: bool = field(default=False, init=False, repr=False)
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, action): def action(self, action: RobotAction) -> RobotAction:
new_action = action.copy() new_action = action.copy()
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
@@ -171,7 +172,7 @@ class EEReferenceAndDelta(ActionProcessorStep):
@ProcessorStepRegistry.register("ee_bounds_and_safety") @ProcessorStepRegistry.register("ee_bounds_and_safety")
@dataclass @dataclass
class EEBoundsAndSafety(ActionProcessorStep): class EEBoundsAndSafety(RobotActionProcessorStep):
""" """
Clips the end-effector pose to predefined bounds and checks for unsafe jumps. Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
@@ -192,7 +193,7 @@ class EEBoundsAndSafety(ActionProcessorStep):
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False) _last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False) _last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict) -> dict: def action(self, act: RobotAction) -> RobotAction:
x = act.get("ee.x", None) x = act.get("ee.x", None)
y = act.get("ee.y", None) y = act.get("ee.y", None)
z = act.get("ee.z", None) z = act.get("ee.z", None)
@@ -266,6 +267,10 @@ class InverseKinematicsEEToJoints(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
new_transition = transition.copy() new_transition = transition.copy()
act = new_transition.get(TransitionKey.ACTION) or {} act = new_transition.get(TransitionKey.ACTION) or {}
if not isinstance(act, dict):
raise ValueError(f"Action should be a RobotAction type got {type(act)}")
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get("ee.x", None) x = act.get("ee.x", None)
@@ -361,6 +366,9 @@ class GripperVelocityToJoint(ProcessorStep):
act = new_transition.get(TransitionKey.ACTION) or {} act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if not isinstance(act, dict):
raise ValueError(f"Action should be a RobotAction type got {type(act)}")
if "gripper" not in act: if "gripper" not in act:
raise ValueError("Required action key 'gripper' not found in transition") raise ValueError("Required action key 'gripper' not found in transition")
+6 -4
View File
@@ -64,10 +64,10 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey
from lerobot.processor.converters import ( from lerobot.processor.converters import (
action_to_transition,
identity_transition, identity_transition,
observation_to_transition, observation_to_transition,
transition_to_action, robot_action_to_transition,
transition_to_robot_action,
) )
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
@@ -140,7 +140,9 @@ def teleop_loop(
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
teleop_action_processor teleop_action_processor
or RobotProcessorPipeline( or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition steps=[IdentityProcessorStep()],
to_transition=robot_action_to_transition,
to_output=identity_transition,
) )
) )
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = ( robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
@@ -148,7 +150,7 @@ def teleop_loop(
or RobotProcessorPipeline( or RobotProcessorPipeline(
steps=[IdentityProcessorStep()], steps=[IdentityProcessorStep()],
to_transition=identity_transition, to_transition=identity_transition,
to_output=transition_to_action, # type: ignore[arg-type] to_output=transition_to_robot_action, # type: ignore[arg-type]
) )
) )
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = ( robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
+8 -8
View File
@@ -49,7 +49,7 @@ def test_batch_to_transition_observation_grouping():
"observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4], "observation.state": [1, 2, 3, 4],
"action": "action_data", "action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
"next.reward": 1.5, "next.reward": 1.5,
"next.done": True, "next.done": True,
"next.truncated": False, "next.truncated": False,
@@ -74,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
# Check other fields # Check other fields
assert transition[TransitionKey.ACTION] == "action_data" assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
assert transition[TransitionKey.REWARD] == 1.5 assert transition[TransitionKey.REWARD] == 1.5
assert transition[TransitionKey.DONE] assert transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED] assert not transition[TransitionKey.TRUNCATED]
@@ -123,7 +123,7 @@ def test_transition_to_batch_observation_flattening():
def test_no_observation_keys(): def test_no_observation_keys():
"""Test behavior when there are no observation.* keys.""" """Test behavior when there are no observation.* keys."""
batch = { batch = {
"action": "action_data", "action": torch.tensor([1.0, 2.0]),
"next.reward": 2.0, "next.reward": 2.0,
"next.done": False, "next.done": False,
"next.truncated": True, "next.truncated": True,
@@ -136,7 +136,7 @@ def test_no_observation_keys():
assert transition[TransitionKey.OBSERVATION] is None assert transition[TransitionKey.OBSERVATION] is None
# Check other fields # Check other fields
assert transition[TransitionKey.ACTION] == "action_data" assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0]))
assert transition[TransitionKey.REWARD] == 2.0 assert transition[TransitionKey.REWARD] == 2.0
assert not transition[TransitionKey.DONE] assert not transition[TransitionKey.DONE]
assert transition[TransitionKey.TRUNCATED] assert transition[TransitionKey.TRUNCATED]
@@ -144,7 +144,7 @@ def test_no_observation_keys():
# Round trip should work # Round trip should work
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] == "action_data" assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
assert reconstructed_batch["next.reward"] == 2.0 assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.done"]
assert reconstructed_batch["next.truncated"] assert reconstructed_batch["next.truncated"]
@@ -153,13 +153,13 @@ def test_no_observation_keys():
def test_minimal_batch(): def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action.""" """Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": "minimal_action"} batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
transition = batch_to_transition(batch) transition = batch_to_transition(batch)
# Check observation # Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionKey.ACTION] == "minimal_action" assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
# Check defaults # Check defaults
assert transition[TransitionKey.REWARD] == 0.0 assert transition[TransitionKey.REWARD] == 0.0
@@ -171,7 +171,7 @@ def test_minimal_batch():
# Round trip # Round trip
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state" assert reconstructed_batch["observation.state"] == "minimal_state"
assert reconstructed_batch["action"] == "minimal_action" assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
assert reconstructed_batch["next.reward"] == 0.0 assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"] assert not reconstructed_batch["next.truncated"]
+104 -59
View File
@@ -38,7 +38,7 @@ def test_state_1d_to_2d():
# Test observation.state # Test observation.state
state_1d = torch.randn(7) state_1d = torch.randn(7)
observation = {OBS_STATE: state_1d} observation = {OBS_STATE: state_1d}
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -54,7 +54,7 @@ def test_env_state_1d_to_2d():
# Test observation.environment_state # Test observation.environment_state
env_state_1d = torch.randn(10) env_state_1d = torch.randn(10)
observation = {OBS_ENV_STATE: env_state_1d} observation = {OBS_ENV_STATE: env_state_1d}
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -70,7 +70,7 @@ def test_image_3d_to_4d():
# Test observation.image # Test observation.image
image_3d = torch.randn(224, 224, 3) image_3d = torch.randn(224, 224, 3)
observation = {OBS_IMAGE: image_3d} observation = {OBS_IMAGE: image_3d}
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -90,7 +90,7 @@ def test_multiple_images_3d_to_4d():
f"{OBS_IMAGES}.camera1": image1_3d, f"{OBS_IMAGES}.camera1": image1_3d,
f"{OBS_IMAGES}.camera2": image2_3d, f"{OBS_IMAGES}.camera2": image2_3d,
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -118,7 +118,7 @@ def test_already_batched_tensors_unchanged():
OBS_ENV_STATE: env_state_2d, OBS_ENV_STATE: env_state_2d,
OBS_IMAGE: image_4d, OBS_IMAGE: image_4d,
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -142,7 +142,7 @@ def test_higher_dimensional_tensors_unchanged():
OBS_STATE: state_3d, OBS_STATE: state_3d,
OBS_IMAGE: image_5d, OBS_IMAGE: image_5d,
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -163,7 +163,7 @@ def test_non_tensor_values_unchanged():
"custom_key": 42, # Integer "custom_key": 42, # Integer
"another_key": {"nested": "dict"}, # Dict "another_key": {"nested": "dict"}, # Dict
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -180,7 +180,7 @@ def test_none_observation():
"""Test processor handles None observation gracefully.""" """Test processor handles None observation gracefully."""
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
transition = create_transition(observation={}, action={}) transition = create_transition(observation={}, action=torch.empty(0))
result = processor(transition) result = processor(transition)
assert result[TransitionKey.OBSERVATION] == {} assert result[TransitionKey.OBSERVATION] == {}
@@ -191,7 +191,7 @@ def test_empty_observation():
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
observation = {} observation = {}
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
@@ -216,7 +216,7 @@ def test_mixed_observation():
"other_tensor": other_tensor, "other_tensor": other_tensor,
"non_tensor": "string_value", "non_tensor": "string_value",
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
@@ -243,7 +243,7 @@ def test_integration_with_robot_processor():
OBS_STATE: torch.randn(7), OBS_STATE: torch.randn(7),
OBS_IMAGE: torch.randn(224, 224, 3), OBS_IMAGE: torch.randn(224, 224, 3),
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = pipeline(transition) result = pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
@@ -299,7 +299,7 @@ def test_save_and_load_pretrained():
# Test functionality of loaded processor # Test functionality of loaded processor
observation = {OBS_STATE: torch.randn(5)} observation = {OBS_STATE: torch.randn(5)}
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = loaded_pipeline(transition) result = loaded_pipeline(transition)
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
@@ -333,7 +333,7 @@ def test_registry_based_save_load():
OBS_STATE: torch.randn(3), OBS_STATE: torch.randn(3),
OBS_IMAGE: torch.randn(100, 100, 3), OBS_IMAGE: torch.randn(100, 100, 3),
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = loaded_pipeline(transition) result = loaded_pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
@@ -355,7 +355,7 @@ def test_device_compatibility():
OBS_STATE: state_1d, OBS_STATE: state_1d,
OBS_IMAGE: image_3d, OBS_IMAGE: image_3d,
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
@@ -415,7 +415,7 @@ def test_edge_case_zero_dimensional_tensors():
OBS_STATE: scalar_tensor, OBS_STATE: scalar_tensor,
"scalar_value": scalar_tensor, "scalar_value": scalar_tensor,
} }
transition = create_transition(observation=observation, action={}) transition = create_transition(observation=observation, action=torch.empty(0))
result = processor(transition) result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
@@ -490,42 +490,43 @@ def test_action_scalar_tensor():
assert torch.equal(result[TransitionKey.ACTION], action_scalar) assert torch.equal(result[TransitionKey.ACTION], action_scalar)
def test_action_non_tensor(): def test_action_non_tensor_raises_error():
"""Test that non-tensor actions remain unchanged.""" """Test that non-tensor actions raise ValueError for PolicyAction processors."""
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
# List action # List action should raise error
action_list = [0.1, 0.2, 0.3, 0.4] action_list = [0.1, 0.2, 0.3, 0.4]
transition = create_transition(action=action_list, observation={}) transition = create_transition(action=action_list)
result = processor(transition) with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
assert result[TransitionKey.ACTION] == action_list processor(transition)
# Numpy array action (as Python object, not converted) # Numpy array action should raise error
action_numpy = np.array([1, 2, 3, 4]) action_numpy = np.array([1, 2, 3, 4])
transition = create_transition(action=action_numpy, observation={}) transition = create_transition(action=action_numpy)
result = processor(transition) with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
assert np.array_equal(result[TransitionKey.ACTION], action_numpy) processor(transition)
# String action (edge case) # String action should raise error
action_string = "forward" action_string = "forward"
transition = create_transition(action=action_string, observation={}) transition = create_transition(action=action_string)
result = processor(transition) with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
assert result[TransitionKey.ACTION] == action_string processor(transition)
# Dict action (structured action) # Dict action should raise error
action_dict = {"linear": [0.5, 0.0], "angular": 0.2} action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
transition = create_transition(action=action_dict, observation={}) transition = create_transition(action=action_dict)
result = processor(transition) with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
assert result[TransitionKey.ACTION] == action_dict processor(transition)
def test_action_none(): def test_action_none():
"""Test that None action is handled correctly.""" """Test that empty action tensor is handled correctly."""
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
transition = create_transition(action={}, observation={}) transition = create_transition(action=torch.empty(0), observation={})
result = processor(transition) result = processor(transition)
assert result[TransitionKey.ACTION] == {} # Empty 1D tensor becomes empty 2D tensor with batch dimension
assert result[TransitionKey.ACTION].shape == (1, 0)
def test_action_with_observation(): def test_action_with_observation():
@@ -630,7 +631,9 @@ def test_task_string_to_list():
# Create complementary data with string task # Create complementary data with string task
complementary_data = {"task": "pick_cube"} complementary_data = {"task": "pick_cube"}
transition = create_transition(action={}, observation={}, complementary_data=complementary_data) transition = create_transition(
action=torch.empty(0), observation={}, complementary_data=complementary_data
)
result = processor(transition) result = processor(transition)
@@ -647,14 +650,18 @@ def test_task_string_validation():
# Valid string task - should be converted to list # Valid string task - should be converted to list
complementary_data = {"task": "valid_task"} complementary_data = {"task": "valid_task"}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["valid_task"] assert processed_comp_data["task"] == ["valid_task"]
# Valid list of strings - should remain unchanged # Valid list of strings - should remain unchanged
complementary_data = {"task": ["task1", "task2"]} complementary_data = {"task": ["task1", "task2"]}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["task1", "task2"] assert processed_comp_data["task"] == ["task1", "task2"]
@@ -676,7 +683,9 @@ def test_task_list_of_strings():
for task_list in test_lists: for task_list in test_lists:
complementary_data = {"task": task_list} complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -690,7 +699,7 @@ def test_complementary_data_none():
"""Test processor handles None complementary_data gracefully.""" """Test processor handles None complementary_data gracefully."""
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
transition = create_transition(complementary_data=None, action={}, observation={}) transition = create_transition(complementary_data=None, action=torch.empty(0), observation={})
result = processor(transition) result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA] == {} assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
@@ -701,7 +710,9 @@ def test_complementary_data_empty():
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
complementary_data = {} complementary_data = {}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -717,7 +728,9 @@ def test_complementary_data_no_task():
"timestamp": 1234567890.0, "timestamp": 1234567890.0,
"extra_info": "some data", "extra_info": "some data",
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -736,7 +749,9 @@ def test_complementary_data_mixed():
"difficulty": "hard", "difficulty": "hard",
"metadata": {"scene": "kitchen"}, "metadata": {"scene": "kitchen"},
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -803,7 +818,9 @@ def test_task_comprehensive_string_cases():
# Test that all string tasks get properly batched # Test that all string tasks get properly batched
for task in string_tasks: for task in string_tasks:
complementary_data = {"task": task} complementary_data = {"task": task}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -825,7 +842,9 @@ def test_task_comprehensive_string_cases():
for task_list in list_tasks: for task_list in list_tasks:
complementary_data = {"task": task_list} complementary_data = {"task": task_list}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -845,7 +864,9 @@ def test_task_preserves_other_keys():
"config": {"speed": "slow", "precision": "high"}, "config": {"speed": "slow", "precision": "high"},
"metrics": [1.0, 2.0, 3.0], "metrics": [1.0, 2.0, 3.0],
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -869,7 +890,9 @@ def test_index_scalar_to_1d():
# Create 0D index tensor (scalar) # Create 0D index tensor (scalar)
index_0d = torch.tensor(42, dtype=torch.int64) index_0d = torch.tensor(42, dtype=torch.int64)
complementary_data = {"index": index_0d} complementary_data = {"index": index_0d}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -886,7 +909,9 @@ def test_task_index_scalar_to_1d():
# Create 0D task_index tensor (scalar) # Create 0D task_index tensor (scalar)
task_index_0d = torch.tensor(7, dtype=torch.int64) task_index_0d = torch.tensor(7, dtype=torch.int64)
complementary_data = {"task_index": task_index_0d} complementary_data = {"task_index": task_index_0d}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -908,7 +933,9 @@ def test_index_and_task_index_together():
"task_index": task_index_0d, "task_index": task_index_0d,
"task": "pick_object", "task": "pick_object",
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -936,13 +963,17 @@ def test_index_already_batched():
# Test 1D (already batched) # Test 1D (already batched)
complementary_data = {"index": index_1d} complementary_data = {"index": index_1d}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
# Test 2D # Test 2D
complementary_data = {"index": index_2d} complementary_data = {"index": index_2d}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
@@ -957,13 +988,17 @@ def test_task_index_already_batched():
# Test 1D (already batched) # Test 1D (already batched)
complementary_data = {"task_index": task_index_1d} complementary_data = {"task_index": task_index_1d}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
# Test 2D # Test 2D
complementary_data = {"task_index": task_index_2d} complementary_data = {"task_index": task_index_2d}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d) assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
@@ -976,7 +1011,9 @@ def test_index_non_tensor_unchanged():
"index": 42, # Plain int, not tensor "index": 42, # Plain int, not tensor
"task_index": [1, 2, 3], # List, not tensor "task_index": [1, 2, 3], # List, not tensor
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -999,7 +1036,9 @@ def test_index_dtype_preservation():
"index": index_0d, "index": index_0d,
"task_index": task_index_0d, "task_index": task_index_0d,
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -1062,7 +1101,9 @@ def test_index_device_compatibility():
"index": index_0d, "index": index_0d,
"task_index": task_index_0d, "task_index": task_index_0d,
} }
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
@@ -1081,7 +1122,9 @@ def test_empty_index_tensor():
# Empty 0D tensor doesn't make sense, but test empty 1D # Empty 0D tensor doesn't make sense, but test empty 1D
index_empty = torch.tensor([], dtype=torch.int64) index_empty = torch.tensor([], dtype=torch.int64)
complementary_data = {"index": index_empty} complementary_data = {"index": index_empty}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
result = processor(transition) result = processor(transition)
@@ -1116,7 +1159,9 @@ def test_task_processing_creates_new_transition():
processor = AddBatchDimensionProcessorStep() processor = AddBatchDimensionProcessorStep()
complementary_data = {"task": "sort_objects"} complementary_data = {"task": "sort_objects"}
transition = create_transition(complementary_data=complementary_data, observation={}, action={}) transition = create_transition(
complementary_data=complementary_data, observation={}, action=torch.empty(0)
)
# Store reference to original transition and complementary_data # Store reference to original transition and complementary_data
original_transition = transition original_transition = transition
+2 -2
View File
@@ -329,14 +329,14 @@ def test_min_max_unnormalization(action_stats_min_max):
assert torch.allclose(unnormalized_action, expected) assert torch.allclose(unnormalized_action, expected)
def test_numpy_action_input(action_stats_mean_std): def test_tensor_action_input(action_stats_mean_std):
features = _create_action_features() features = _create_action_features()
norm_map = _create_action_norm_map_mean_std() norm_map = _create_action_norm_map_mean_std()
unnormalizer = UnnormalizerProcessorStep( unnormalizer = UnnormalizerProcessorStep(
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
) )
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32)
transition = create_transition(action=normalized_action) transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition) unnormalized_transition = unnormalizer(transition)
+4 -4
View File
@@ -371,12 +371,12 @@ def test_sac_processor_edge_cases():
assert processed[TransitionKey.OBSERVATION] == {} assert processed[TransitionKey.OBSERVATION] == {}
assert processed[TransitionKey.ACTION].shape == (1, 5) assert processed[TransitionKey.ACTION].shape == (1, 5)
# Test with None action # Test with zero action (representing "null" action)
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action={}) transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5))
processed = preprocessor(transition) processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
# When action is None, it may still be present with None value # Action should be present and batched, even if it's zeros
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None assert processed[TransitionKey.ACTION].shape == (1, 5)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")