mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908)
* refactor(processor): split action from policy, robots and environment
  - Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions.
  - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention.
  - Enhanced type annotations for action parameters to improve code readability and maintainability.

* refactor(converters): rename robot_transition_to_action to transition_to_robot_action
  - Updated function names across multiple files to improve clarity and consistency in processing robot actions.
  - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention.
  - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence.

* refactor(converters): update references to transition_to_robot_action
  - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions.
  - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability.

* refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep
  - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing.
  - This modification enhances the clarity of the class's role in the processing pipeline.

* fix(processor): main action processor can take also EnvAction

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -25,7 +25,7 @@ from lerobot.processor import RobotProcessorPipeline
|
|||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
identity_transition,
|
identity_transition,
|
||||||
observation_to_transition,
|
observation_to_transition,
|
||||||
transition_to_action,
|
transition_to_robot_action,
|
||||||
)
|
)
|
||||||
from lerobot.record import record_loop
|
from lerobot.record import record_loop
|
||||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||||
@@ -76,7 +76,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
to_transition=identity_transition,
|
to_transition=identity_transition,
|
||||||
to_output=transition_to_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert joint observation to ee pose observation
|
# Build pipeline to convert joint observation to ee pose observation
|
||||||
|
|||||||
@@ -22,10 +22,10 @@ from lerobot.datasets.utils import combine_feature_dicts
|
|||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import RobotProcessorPipeline
|
from lerobot.processor import RobotProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
action_to_transition,
|
|
||||||
identity_transition,
|
identity_transition,
|
||||||
observation_to_transition,
|
observation_to_transition,
|
||||||
transition_to_action,
|
robot_action_to_transition,
|
||||||
|
transition_to_robot_action,
|
||||||
)
|
)
|
||||||
from lerobot.record import record_loop
|
from lerobot.record import record_loop
|
||||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||||
@@ -89,7 +89,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
|
|||||||
max_ee_twist_step_rad=0.50,
|
max_ee_twist_step_rad=0.50,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
to_transition=action_to_transition,
|
to_transition=robot_action_to_transition,
|
||||||
to_output=identity_transition,
|
to_output=identity_transition,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
to_transition=identity_transition,
|
to_transition=identity_transition,
|
||||||
to_output=transition_to_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert joint observation to ee pose observation
|
# Build pipeline to convert joint observation to ee pose observation
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import time
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import RobotProcessorPipeline
|
from lerobot.processor import RobotProcessorPipeline
|
||||||
from lerobot.processor.converters import action_to_transition, transition_to_action
|
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
|
||||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||||
AddRobotObservationAsComplimentaryData,
|
AddRobotObservationAsComplimentaryData,
|
||||||
@@ -59,8 +59,8 @@ robot_ee_to_joints_processor = RobotProcessorPipeline(
|
|||||||
initial_guess_current_joints=False, # Because replay is open loop
|
initial_guess_current_joints=False, # Because replay is open loop
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
to_transition=action_to_transition,
|
to_transition=robot_action_to_transition,
|
||||||
to_output=transition_to_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
robot_ee_to_joints_processor.reset()
|
robot_ee_to_joints_processor.reset()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import time
|
|||||||
|
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import RobotProcessorPipeline
|
from lerobot.processor import RobotProcessorPipeline
|
||||||
from lerobot.processor.converters import action_to_transition, transition_to_action
|
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
|
||||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||||
AddRobotObservationAsComplimentaryData,
|
AddRobotObservationAsComplimentaryData,
|
||||||
@@ -72,8 +72,8 @@ phone_to_robot_joints_processor = RobotProcessorPipeline(
|
|||||||
speed_factor=20.0,
|
speed_factor=20.0,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
to_transition=action_to_transition,
|
to_transition=robot_action_to_transition,
|
||||||
to_output=transition_to_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|||||||
@@ -46,11 +46,13 @@ from .pipeline import (
|
|||||||
IdentityProcessorStep,
|
IdentityProcessorStep,
|
||||||
InfoProcessorStep,
|
InfoProcessorStep,
|
||||||
ObservationProcessorStep,
|
ObservationProcessorStep,
|
||||||
|
PolicyActionProcessorStep,
|
||||||
PolicyProcessorPipeline,
|
PolicyProcessorPipeline,
|
||||||
ProcessorKwargs,
|
ProcessorKwargs,
|
||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
RewardProcessorStep,
|
RewardProcessorStep,
|
||||||
|
RobotActionProcessorStep,
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
TruncatedProcessorStep,
|
TruncatedProcessorStep,
|
||||||
)
|
)
|
||||||
@@ -81,10 +83,12 @@ __all__ = [
|
|||||||
"NormalizerProcessorStep",
|
"NormalizerProcessorStep",
|
||||||
"Numpy2TorchActionProcessorStep",
|
"Numpy2TorchActionProcessorStep",
|
||||||
"ObservationProcessorStep",
|
"ObservationProcessorStep",
|
||||||
|
"PolicyActionProcessorStep",
|
||||||
"PolicyProcessorPipeline",
|
"PolicyProcessorPipeline",
|
||||||
"ProcessorKwargs",
|
"ProcessorKwargs",
|
||||||
"ProcessorStep",
|
"ProcessorStep",
|
||||||
"ProcessorStepRegistry",
|
"ProcessorStepRegistry",
|
||||||
|
"RobotActionProcessorStep",
|
||||||
"RenameObservationsProcessorStep",
|
"RenameObservationsProcessorStep",
|
||||||
"RewardClassifierProcessorStep",
|
"RewardClassifierProcessorStep",
|
||||||
"RewardProcessorStep",
|
"RewardProcessorStep",
|
||||||
|
|||||||
@@ -27,11 +27,11 @@ from torch import Tensor
|
|||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
from .core import EnvTransition
|
from .core import EnvTransition, PolicyAction
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
ActionProcessorStep,
|
|
||||||
ComplementaryDataProcessorStep,
|
ComplementaryDataProcessorStep,
|
||||||
ObservationProcessorStep,
|
ObservationProcessorStep,
|
||||||
|
PolicyActionProcessorStep,
|
||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
)
|
)
|
||||||
@@ -39,14 +39,14 @@ from .pipeline import (
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||||
class AddBatchDimensionActionStep(ActionProcessorStep):
|
class AddBatchDimensionActionStep(PolicyActionProcessorStep):
|
||||||
"""
|
"""
|
||||||
Processor step to add a batch dimension to a 1D tensor action.
|
Processor step to add a batch dimension to a 1D tensor action.
|
||||||
|
|
||||||
This is useful for creating a batch of size 1 from a single action sample.
|
This is useful for creating a batch of size 1 from a single action sample.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def action(self, action: Tensor) -> Tensor:
|
def action(self, action: PolicyAction) -> PolicyAction:
|
||||||
"""
|
"""
|
||||||
Adds a batch dimension to the action if it's a 1D tensor.
|
Adds a batch dimension to the action if it's a 1D tensor.
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ class AddBatchDimensionActionStep(ActionProcessorStep):
|
|||||||
Returns:
|
Returns:
|
||||||
The action tensor with an added batch dimension.
|
The action tensor with an added batch dimension.
|
||||||
"""
|
"""
|
||||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
if action.dim() != 1:
|
||||||
return action
|
return action
|
||||||
return action.unsqueeze(0)
|
return action.unsqueeze(0)
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import torch
|
|||||||
|
|
||||||
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||||
|
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
@@ -243,7 +243,7 @@ def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransiti
|
|||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
observation: dict[str, Any] | None = None,
|
observation: dict[str, Any] | None = None,
|
||||||
action: dict[str, Any] | None = None,
|
action: PolicyAction | RobotAction | None = None,
|
||||||
reward: float = 0.0,
|
reward: float = 0.0,
|
||||||
done: bool = False,
|
done: bool = False,
|
||||||
truncated: bool = False,
|
truncated: bool = False,
|
||||||
@@ -276,9 +276,9 @@ def create_transition(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def action_to_transition(action: dict[str, Any]) -> EnvTransition:
|
def robot_action_to_transition(action: RobotAction) -> EnvTransition:
|
||||||
"""
|
"""
|
||||||
Convert a raw action dictionary into a standardized `EnvTransition`.
|
Convert a raw robot action dictionary into a standardized `EnvTransition`.
|
||||||
|
|
||||||
The keys in the action dictionary are prefixed with "action." and stored under
|
The keys in the action dictionary are prefixed with "action." and stored under
|
||||||
the `ACTION` key in the transition. Values are converted to tensors, except for
|
the `ACTION` key in the transition. Values are converted to tensors, except for
|
||||||
@@ -315,9 +315,9 @@ def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
|
|||||||
return create_transition(observation={**state, **image_observations}, action={})
|
return create_transition(observation={**state, **image_observations}, action={})
|
||||||
|
|
||||||
|
|
||||||
def transition_to_action(transition: EnvTransition) -> dict[str, Any]:
|
def transition_to_robot_action(transition: EnvTransition) -> RobotAction:
|
||||||
"""
|
"""
|
||||||
Extract a raw action dictionary for a robot from an `EnvTransition`.
|
Extract a raw robot action dictionary for a robot from an `EnvTransition`.
|
||||||
|
|
||||||
This function searches for keys in the format "action.*.pos" or "action.*.vel"
|
This function searches for keys in the format "action.*.pos" or "action.*.vel"
|
||||||
and converts them into a flat dictionary suitable for sending to a robot controller.
|
and converts them into a flat dictionary suitable for sending to a robot controller.
|
||||||
@@ -460,6 +460,10 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
|||||||
if not isinstance(batch, dict):
|
if not isinstance(batch, dict):
|
||||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||||
|
|
||||||
|
action = batch.get("action")
|
||||||
|
if action is not None and not isinstance(action, PolicyAction):
|
||||||
|
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||||
|
|
||||||
# Extract observation and complementary data keys.
|
# Extract observation and complementary data keys.
|
||||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||||
complementary_data = _extract_complementary_data(batch)
|
complementary_data = _extract_complementary_data(batch)
|
||||||
|
|||||||
@@ -17,8 +17,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypeAlias, TypedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -35,11 +36,16 @@ class TransitionKey(str, Enum):
|
|||||||
COMPLEMENTARY_DATA = "complementary_data"
|
COMPLEMENTARY_DATA = "complementary_data"
|
||||||
|
|
||||||
|
|
||||||
|
PolicyAction: TypeAlias = torch.Tensor
|
||||||
|
RobotAction: TypeAlias = dict[str, Any]
|
||||||
|
EnvAction: TypeAlias = np.ndarray
|
||||||
|
|
||||||
|
|
||||||
EnvTransition = TypedDict(
|
EnvTransition = TypedDict(
|
||||||
"EnvTransition",
|
"EnvTransition",
|
||||||
{
|
{
|
||||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None,
|
||||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||||
|
|||||||
@@ -16,11 +16,10 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
|
|
||||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
from .core import PolicyAction, RobotAction
|
||||||
|
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||||
@@ -40,7 +39,10 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
|||||||
|
|
||||||
use_gripper: bool = True
|
use_gripper: bool = True
|
||||||
|
|
||||||
def action(self, action: Tensor) -> dict:
|
def action(self, action: PolicyAction) -> RobotAction:
|
||||||
|
if not isinstance(action, PolicyAction):
|
||||||
|
raise ValueError("Only PolicyAction is supported for this processor")
|
||||||
|
|
||||||
if action.dim() > 1:
|
if action.dim() > 1:
|
||||||
action = action.squeeze(0)
|
action = action.squeeze(0)
|
||||||
|
|
||||||
@@ -69,7 +71,7 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
|
class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||||
"""
|
"""
|
||||||
Maps delta actions from teleoperators to robot target actions for inverse kinematics.
|
Maps delta actions from teleoperators to robot target actions for inverse kinematics.
|
||||||
|
|
||||||
@@ -89,7 +91,7 @@ class MapDeltaActionToRobotActionStep(ActionProcessorStep):
|
|||||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||||
|
|
||||||
def action(self, action: dict) -> dict:
|
def action(self, action: RobotAction) -> RobotAction:
|
||||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||||
delta_x = action.pop("delta_x", 0.0)
|
delta_x = action.pop("delta_x", 0.0)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import torch
|
|||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.utils.utils import get_safe_torch_device
|
from lerobot.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||||
|
|
||||||
|
|
||||||
@@ -129,6 +129,10 @@ class DeviceProcessorStep(ProcessorStep):
|
|||||||
A new `EnvTransition` object with all tensors moved to the target device and dtype.
|
A new `EnvTransition` object with all tensors moved to the target device and dtype.
|
||||||
"""
|
"""
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
|
|
||||||
|
if action is not None and not isinstance(action, PolicyAction):
|
||||||
|
raise ValueError(f"If action is not None should be a PolicyAction type got {type(action)}")
|
||||||
|
|
||||||
simple_tensor_keys = [
|
simple_tensor_keys = [
|
||||||
TransitionKey.ACTION,
|
TransitionKey.ACTION,
|
||||||
|
|||||||
@@ -16,12 +16,10 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
|
|
||||||
from .converters import to_tensor
|
from .converters import to_tensor
|
||||||
|
from .core import EnvAction, PolicyAction
|
||||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||||
|
|
||||||
|
|
||||||
@@ -42,10 +40,10 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
|||||||
|
|
||||||
squeeze_batch_dim: bool = True
|
squeeze_batch_dim: bool = True
|
||||||
|
|
||||||
def action(self, action: torch.Tensor) -> np.ndarray:
|
def action(self, action: PolicyAction) -> EnvAction:
|
||||||
if not isinstance(action, torch.Tensor):
|
if not isinstance(action, PolicyAction):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected torch.Tensor or None, got {type(action).__name__}. "
|
f"Expected PolicyAction or None, got {type(action).__name__}. "
|
||||||
"Use appropriate processor for non-tensor actions."
|
"Use appropriate processor for non-tensor actions."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,8 +78,8 @@ class Numpy2TorchActionProcessorStep(ActionProcessorStep):
|
|||||||
by a policy or model.
|
by a policy or model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def action(self, action: np.ndarray) -> torch.Tensor:
|
def action(self, action: EnvAction) -> PolicyAction:
|
||||||
if not isinstance(action, np.ndarray):
|
if not isinstance(action, EnvAction):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||||
"Use appropriate processor for non-tensor actions."
|
"Use appropriate processor for non-tensor actions."
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
|||||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||||
from lerobot.teleoperators.utils import TeleopEvents
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
|
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
ComplementaryDataProcessorStep,
|
ComplementaryDataProcessorStep,
|
||||||
InfoProcessorStep,
|
InfoProcessorStep,
|
||||||
@@ -416,8 +416,8 @@ class InterventionActionProcessorStep(ProcessorStep):
|
|||||||
reward, and termination status.
|
reward, and termination status.
|
||||||
"""
|
"""
|
||||||
action = transition.get(TransitionKey.ACTION)
|
action = transition.get(TransitionKey.ACTION)
|
||||||
if action is None:
|
if not isinstance(action, PolicyAction):
|
||||||
return transition
|
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||||
|
|
||||||
# Get intervention signals from complementary data
|
# Get intervention signals from complementary data
|
||||||
info = transition.get(TransitionKey.INFO, {})
|
info = transition.get(TransitionKey.INFO, {})
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatur
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
from .converters import from_tensor_to_numpy, to_tensor
|
from .converters import from_tensor_to_numpy, to_tensor
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||||
|
|
||||||
|
|
||||||
@@ -345,7 +345,13 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
|||||||
|
|
||||||
# Handle action normalization.
|
# Handle action normalization.
|
||||||
action = new_transition.get(TransitionKey.ACTION)
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
if action is not None:
|
|
||||||
|
if action is None:
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
if not isinstance(action, PolicyAction):
|
||||||
|
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||||
|
|
||||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
|
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
|
||||||
|
|
||||||
return new_transition
|
return new_transition
|
||||||
@@ -401,7 +407,12 @@ class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
|||||||
|
|
||||||
# Handle action unnormalization.
|
# Handle action unnormalization.
|
||||||
action = new_transition.get(TransitionKey.ACTION)
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
if action is not None:
|
|
||||||
|
if action is None:
|
||||||
|
return new_transition
|
||||||
|
if not isinstance(action, PolicyAction):
|
||||||
|
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||||
|
|
||||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
|
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
|
||||||
|
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from safetensors.torch import load_file, save_file
|
|||||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
|
|
||||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||||
|
|
||||||
# Type variable for generic processor output type
|
# Type variable for generic processor output type
|
||||||
TOutput = TypeVar("TOutput")
|
TOutput = TypeVar("TOutput")
|
||||||
@@ -859,7 +859,9 @@ class ActionProcessorStep(ProcessorStep, ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def action(self, action) -> Any | torch.Tensor:
|
def action(
|
||||||
|
self, action: PolicyAction | RobotAction | EnvAction
|
||||||
|
) -> PolicyAction | RobotAction | EnvAction:
|
||||||
"""Process the action component.
|
"""Process the action component.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -878,6 +880,82 @@ class ActionProcessorStep(ProcessorStep, ABC):
|
|||||||
if action is None:
|
if action is None:
|
||||||
raise ValueError("ActionProcessorStep requires an action in the transition.")
|
raise ValueError("ActionProcessorStep requires an action in the transition.")
|
||||||
|
|
||||||
|
processed_action = self.action(action)
|
||||||
|
new_transition[TransitionKey.ACTION] = processed_action
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
|
||||||
|
class RobotActionProcessorStep(ProcessorStep, ABC):
|
||||||
|
"""Base class for processors that modify only the robot action component of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `action` method to implement custom robot action processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed action
|
||||||
|
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||||
|
manipulation, focusing only on the specific robot action processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def action(self, action: RobotAction) -> RobotAction:
|
||||||
|
"""Process the robot action component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The robot action to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed robot action
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
self._current_transition = transition.copy()
|
||||||
|
new_transition = self._current_transition
|
||||||
|
|
||||||
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
|
# NOTE: We can't use isinstance(action, RobotAction) because RobotAction is a dict[str, Any]
|
||||||
|
# because Any is generic
|
||||||
|
if not isinstance(action, dict):
|
||||||
|
raise ValueError(f"Action should be a RobotAction type got {type(action)}")
|
||||||
|
|
||||||
|
processed_action = self.action(action=action)
|
||||||
|
new_transition[TransitionKey.ACTION] = processed_action
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyActionProcessorStep(ProcessorStep, ABC):
|
||||||
|
"""Base class for processors that modify only the policy action component of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `action` method to implement custom policy action processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed action
|
||||||
|
into the transition dict, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
|
||||||
|
|
||||||
|
By inheriting from this class, you avoid writing repetitive code to handle transition dict
|
||||||
|
manipulation, focusing only on the specific policy action processing logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def action(self, action: PolicyAction) -> PolicyAction:
|
||||||
|
"""Process the policy action component.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The policy action to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed policy action
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
self._current_transition = transition.copy()
|
||||||
|
new_transition = self._current_transition
|
||||||
|
|
||||||
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
|
if not isinstance(action, PolicyAction):
|
||||||
|
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||||
|
|
||||||
processed_action = self.action(action)
|
processed_action = self.action(action)
|
||||||
new_transition[TransitionKey.ACTION] = processed_action
|
new_transition[TransitionKey.ACTION] = processed_action
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|||||||
@@ -85,11 +85,11 @@ from lerobot.processor import (
|
|||||||
TransitionKey,
|
TransitionKey,
|
||||||
)
|
)
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
action_to_transition,
|
|
||||||
identity_transition,
|
identity_transition,
|
||||||
observation_to_transition,
|
observation_to_transition,
|
||||||
transition_to_action,
|
robot_action_to_transition,
|
||||||
transition_to_dataset_frame,
|
transition_to_dataset_frame,
|
||||||
|
transition_to_robot_action,
|
||||||
)
|
)
|
||||||
from lerobot.processor.rename_processor import rename_stats
|
from lerobot.processor.rename_processor import rename_stats
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
@@ -255,7 +255,9 @@ def record_loop(
|
|||||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||||
teleop_action_processor
|
teleop_action_processor
|
||||||
or RobotProcessorPipeline(
|
or RobotProcessorPipeline(
|
||||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition
|
steps=[IdentityProcessorStep()],
|
||||||
|
to_transition=robot_action_to_transition,
|
||||||
|
to_output=identity_transition,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
|
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
|
||||||
@@ -263,7 +265,7 @@ def record_loop(
|
|||||||
or RobotProcessorPipeline(
|
or RobotProcessorPipeline(
|
||||||
steps=[IdentityProcessorStep()],
|
steps=[IdentityProcessorStep()],
|
||||||
to_transition=identity_transition,
|
to_transition=identity_transition,
|
||||||
to_output=transition_to_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
|
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from pprint import pformat
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
||||||
from lerobot.processor.converters import action_to_transition, transition_to_action
|
from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
@@ -97,8 +97,8 @@ def replay(cfg: ReplayConfig):
|
|||||||
# Initialize robot action processor with default if not provided
|
# Initialize robot action processor with default if not provided
|
||||||
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
|
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
|
||||||
steps=[IdentityProcessorStep()],
|
steps=[IdentityProcessorStep()],
|
||||||
to_transition=action_to_transition,
|
to_transition=robot_action_to_transition,
|
||||||
to_output=transition_to_action, # type: ignore[arg-type]
|
to_output=transition_to_robot_action, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset processor
|
# Reset processor
|
||||||
|
|||||||
@@ -22,21 +22,22 @@ from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeatur
|
|||||||
from lerobot.constants import OBS_STATE
|
from lerobot.constants import OBS_STATE
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ActionProcessorStep,
|
|
||||||
ComplementaryDataProcessorStep,
|
ComplementaryDataProcessorStep,
|
||||||
EnvTransition,
|
EnvTransition,
|
||||||
ObservationProcessorStep,
|
ObservationProcessorStep,
|
||||||
ProcessorStep,
|
ProcessorStep,
|
||||||
ProcessorStepRegistry,
|
ProcessorStepRegistry,
|
||||||
|
RobotActionProcessorStep,
|
||||||
TransitionKey,
|
TransitionKey,
|
||||||
)
|
)
|
||||||
|
from lerobot.processor.core import RobotAction
|
||||||
from lerobot.robots.robot import Robot
|
from lerobot.robots.robot import Robot
|
||||||
from lerobot.utils.rotation import Rotation
|
from lerobot.utils.rotation import Rotation
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||||
@dataclass
|
@dataclass
|
||||||
class EEReferenceAndDelta(ActionProcessorStep):
|
class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||||
"""
|
"""
|
||||||
Computes a target end-effector pose from a relative delta command.
|
Computes a target end-effector pose from a relative delta command.
|
||||||
|
|
||||||
@@ -72,7 +73,7 @@ class EEReferenceAndDelta(ActionProcessorStep):
|
|||||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||||
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
|
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action: RobotAction) -> RobotAction:
|
||||||
new_action = action.copy()
|
new_action = action.copy()
|
||||||
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
|
||||||
@@ -171,7 +172,7 @@ class EEReferenceAndDelta(ActionProcessorStep):
|
|||||||
|
|
||||||
@ProcessorStepRegistry.register("ee_bounds_and_safety")
|
@ProcessorStepRegistry.register("ee_bounds_and_safety")
|
||||||
@dataclass
|
@dataclass
|
||||||
class EEBoundsAndSafety(ActionProcessorStep):
|
class EEBoundsAndSafety(RobotActionProcessorStep):
|
||||||
"""
|
"""
|
||||||
Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
|
Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
|
||||||
|
|
||||||
@@ -192,7 +193,7 @@ class EEBoundsAndSafety(ActionProcessorStep):
|
|||||||
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
def action(self, act: dict) -> dict:
|
def action(self, act: RobotAction) -> RobotAction:
|
||||||
x = act.get("ee.x", None)
|
x = act.get("ee.x", None)
|
||||||
y = act.get("ee.y", None)
|
y = act.get("ee.y", None)
|
||||||
z = act.get("ee.z", None)
|
z = act.get("ee.z", None)
|
||||||
@@ -266,6 +267,10 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||||
|
|
||||||
|
if not isinstance(act, dict):
|
||||||
|
raise ValueError(f"Action should be a RobotAction type got {type(act)}")
|
||||||
|
|
||||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|
||||||
x = act.get("ee.x", None)
|
x = act.get("ee.x", None)
|
||||||
@@ -361,6 +366,9 @@ class GripperVelocityToJoint(ProcessorStep):
|
|||||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|
||||||
|
if not isinstance(act, dict):
|
||||||
|
raise ValueError(f"Action should be a RobotAction type got {type(act)}")
|
||||||
|
|
||||||
if "gripper" not in act:
|
if "gripper" not in act:
|
||||||
raise ValueError("Required action key 'gripper' not found in transition")
|
raise ValueError("Required action key 'gripper' not found in transition")
|
||||||
|
|
||||||
|
|||||||
@@ -64,10 +64,10 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey
|
from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
action_to_transition,
|
|
||||||
identity_transition,
|
identity_transition,
|
||||||
observation_to_transition,
|
observation_to_transition,
|
||||||
transition_to_action,
|
robot_action_to_transition,
|
||||||
|
transition_to_robot_action,
|
||||||
)
|
)
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
@@ -140,7 +140,9 @@ def teleop_loop(
|
|||||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||||
teleop_action_processor
|
teleop_action_processor
|
||||||
or RobotProcessorPipeline(
|
or RobotProcessorPipeline(
|
||||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition
|
steps=[IdentityProcessorStep()],
|
||||||
|
to_transition=robot_action_to_transition,
|
||||||
|
to_output=identity_transition,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
|
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
|
||||||
@@ -148,7 +150,7 @@ def teleop_loop(
|
|||||||
or RobotProcessorPipeline(
|
or RobotProcessorPipeline(
|
||||||
steps=[IdentityProcessorStep()],
|
steps=[IdentityProcessorStep()],
|
||||||
to_transition=identity_transition,
|
to_transition=identity_transition,
|
||||||
to_output=transition_to_action, # type: ignore[arg-type]
|
to_output=transition_to_robot_action, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
|
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def test_batch_to_transition_observation_grouping():
|
|||||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||||
"observation.state": [1, 2, 3, 4],
|
"observation.state": [1, 2, 3, 4],
|
||||||
"action": "action_data",
|
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||||
"next.reward": 1.5,
|
"next.reward": 1.5,
|
||||||
"next.done": True,
|
"next.done": True,
|
||||||
"next.truncated": False,
|
"next.truncated": False,
|
||||||
@@ -74,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
|
|||||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||||
|
|
||||||
# Check other fields
|
# Check other fields
|
||||||
assert transition[TransitionKey.ACTION] == "action_data"
|
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||||
assert transition[TransitionKey.REWARD] == 1.5
|
assert transition[TransitionKey.REWARD] == 1.5
|
||||||
assert transition[TransitionKey.DONE]
|
assert transition[TransitionKey.DONE]
|
||||||
assert not transition[TransitionKey.TRUNCATED]
|
assert not transition[TransitionKey.TRUNCATED]
|
||||||
@@ -123,7 +123,7 @@ def test_transition_to_batch_observation_flattening():
|
|||||||
def test_no_observation_keys():
|
def test_no_observation_keys():
|
||||||
"""Test behavior when there are no observation.* keys."""
|
"""Test behavior when there are no observation.* keys."""
|
||||||
batch = {
|
batch = {
|
||||||
"action": "action_data",
|
"action": torch.tensor([1.0, 2.0]),
|
||||||
"next.reward": 2.0,
|
"next.reward": 2.0,
|
||||||
"next.done": False,
|
"next.done": False,
|
||||||
"next.truncated": True,
|
"next.truncated": True,
|
||||||
@@ -136,7 +136,7 @@ def test_no_observation_keys():
|
|||||||
assert transition[TransitionKey.OBSERVATION] is None
|
assert transition[TransitionKey.OBSERVATION] is None
|
||||||
|
|
||||||
# Check other fields
|
# Check other fields
|
||||||
assert transition[TransitionKey.ACTION] == "action_data"
|
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0]))
|
||||||
assert transition[TransitionKey.REWARD] == 2.0
|
assert transition[TransitionKey.REWARD] == 2.0
|
||||||
assert not transition[TransitionKey.DONE]
|
assert not transition[TransitionKey.DONE]
|
||||||
assert transition[TransitionKey.TRUNCATED]
|
assert transition[TransitionKey.TRUNCATED]
|
||||||
@@ -144,7 +144,7 @@ def test_no_observation_keys():
|
|||||||
|
|
||||||
# Round trip should work
|
# Round trip should work
|
||||||
reconstructed_batch = transition_to_batch(transition)
|
reconstructed_batch = transition_to_batch(transition)
|
||||||
assert reconstructed_batch["action"] == "action_data"
|
assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
|
||||||
assert reconstructed_batch["next.reward"] == 2.0
|
assert reconstructed_batch["next.reward"] == 2.0
|
||||||
assert not reconstructed_batch["next.done"]
|
assert not reconstructed_batch["next.done"]
|
||||||
assert reconstructed_batch["next.truncated"]
|
assert reconstructed_batch["next.truncated"]
|
||||||
@@ -153,13 +153,13 @@ def test_no_observation_keys():
|
|||||||
|
|
||||||
def test_minimal_batch():
|
def test_minimal_batch():
|
||||||
"""Test with minimal batch containing only observation.* and action."""
|
"""Test with minimal batch containing only observation.* and action."""
|
||||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
|
||||||
|
|
||||||
transition = batch_to_transition(batch)
|
transition = batch_to_transition(batch)
|
||||||
|
|
||||||
# Check observation
|
# Check observation
|
||||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
|
||||||
|
|
||||||
# Check defaults
|
# Check defaults
|
||||||
assert transition[TransitionKey.REWARD] == 0.0
|
assert transition[TransitionKey.REWARD] == 0.0
|
||||||
@@ -171,7 +171,7 @@ def test_minimal_batch():
|
|||||||
# Round trip
|
# Round trip
|
||||||
reconstructed_batch = transition_to_batch(transition)
|
reconstructed_batch = transition_to_batch(transition)
|
||||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||||
assert reconstructed_batch["action"] == "minimal_action"
|
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
|
||||||
assert reconstructed_batch["next.reward"] == 0.0
|
assert reconstructed_batch["next.reward"] == 0.0
|
||||||
assert not reconstructed_batch["next.done"]
|
assert not reconstructed_batch["next.done"]
|
||||||
assert not reconstructed_batch["next.truncated"]
|
assert not reconstructed_batch["next.truncated"]
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def test_state_1d_to_2d():
|
|||||||
# Test observation.state
|
# Test observation.state
|
||||||
state_1d = torch.randn(7)
|
state_1d = torch.randn(7)
|
||||||
observation = {OBS_STATE: state_1d}
|
observation = {OBS_STATE: state_1d}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -54,7 +54,7 @@ def test_env_state_1d_to_2d():
|
|||||||
# Test observation.environment_state
|
# Test observation.environment_state
|
||||||
env_state_1d = torch.randn(10)
|
env_state_1d = torch.randn(10)
|
||||||
observation = {OBS_ENV_STATE: env_state_1d}
|
observation = {OBS_ENV_STATE: env_state_1d}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ def test_image_3d_to_4d():
|
|||||||
# Test observation.image
|
# Test observation.image
|
||||||
image_3d = torch.randn(224, 224, 3)
|
image_3d = torch.randn(224, 224, 3)
|
||||||
observation = {OBS_IMAGE: image_3d}
|
observation = {OBS_IMAGE: image_3d}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -90,7 +90,7 @@ def test_multiple_images_3d_to_4d():
|
|||||||
f"{OBS_IMAGES}.camera1": image1_3d,
|
f"{OBS_IMAGES}.camera1": image1_3d,
|
||||||
f"{OBS_IMAGES}.camera2": image2_3d,
|
f"{OBS_IMAGES}.camera2": image2_3d,
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -118,7 +118,7 @@ def test_already_batched_tensors_unchanged():
|
|||||||
OBS_ENV_STATE: env_state_2d,
|
OBS_ENV_STATE: env_state_2d,
|
||||||
OBS_IMAGE: image_4d,
|
OBS_IMAGE: image_4d,
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -142,7 +142,7 @@ def test_higher_dimensional_tensors_unchanged():
|
|||||||
OBS_STATE: state_3d,
|
OBS_STATE: state_3d,
|
||||||
OBS_IMAGE: image_5d,
|
OBS_IMAGE: image_5d,
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -163,7 +163,7 @@ def test_non_tensor_values_unchanged():
|
|||||||
"custom_key": 42, # Integer
|
"custom_key": 42, # Integer
|
||||||
"another_key": {"nested": "dict"}, # Dict
|
"another_key": {"nested": "dict"}, # Dict
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ def test_none_observation():
|
|||||||
"""Test processor handles None observation gracefully."""
|
"""Test processor handles None observation gracefully."""
|
||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
transition = create_transition(observation={}, action={})
|
transition = create_transition(observation={}, action=torch.empty(0))
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
assert result[TransitionKey.OBSERVATION] == {}
|
assert result[TransitionKey.OBSERVATION] == {}
|
||||||
@@ -191,7 +191,7 @@ def test_empty_observation():
|
|||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
observation = {}
|
observation = {}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -216,7 +216,7 @@ def test_mixed_observation():
|
|||||||
"other_tensor": other_tensor,
|
"other_tensor": other_tensor,
|
||||||
"non_tensor": "string_value",
|
"non_tensor": "string_value",
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
processed_obs = result[TransitionKey.OBSERVATION]
|
||||||
@@ -243,7 +243,7 @@ def test_integration_with_robot_processor():
|
|||||||
OBS_STATE: torch.randn(7),
|
OBS_STATE: torch.randn(7),
|
||||||
OBS_IMAGE: torch.randn(224, 224, 3),
|
OBS_IMAGE: torch.randn(224, 224, 3),
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = pipeline(transition)
|
result = pipeline(transition)
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
processed_obs = result[TransitionKey.OBSERVATION]
|
||||||
@@ -299,7 +299,7 @@ def test_save_and_load_pretrained():
|
|||||||
|
|
||||||
# Test functionality of loaded processor
|
# Test functionality of loaded processor
|
||||||
observation = {OBS_STATE: torch.randn(5)}
|
observation = {OBS_STATE: torch.randn(5)}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = loaded_pipeline(transition)
|
result = loaded_pipeline(transition)
|
||||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
|
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
|
||||||
@@ -333,7 +333,7 @@ def test_registry_based_save_load():
|
|||||||
OBS_STATE: torch.randn(3),
|
OBS_STATE: torch.randn(3),
|
||||||
OBS_IMAGE: torch.randn(100, 100, 3),
|
OBS_IMAGE: torch.randn(100, 100, 3),
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = loaded_pipeline(transition)
|
result = loaded_pipeline(transition)
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
processed_obs = result[TransitionKey.OBSERVATION]
|
||||||
@@ -355,7 +355,7 @@ def test_device_compatibility():
|
|||||||
OBS_STATE: state_1d,
|
OBS_STATE: state_1d,
|
||||||
OBS_IMAGE: image_3d,
|
OBS_IMAGE: image_3d,
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
processed_obs = result[TransitionKey.OBSERVATION]
|
||||||
@@ -415,7 +415,7 @@ def test_edge_case_zero_dimensional_tensors():
|
|||||||
OBS_STATE: scalar_tensor,
|
OBS_STATE: scalar_tensor,
|
||||||
"scalar_value": scalar_tensor,
|
"scalar_value": scalar_tensor,
|
||||||
}
|
}
|
||||||
transition = create_transition(observation=observation, action={})
|
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
processed_obs = result[TransitionKey.OBSERVATION]
|
processed_obs = result[TransitionKey.OBSERVATION]
|
||||||
@@ -490,42 +490,43 @@ def test_action_scalar_tensor():
|
|||||||
assert torch.equal(result[TransitionKey.ACTION], action_scalar)
|
assert torch.equal(result[TransitionKey.ACTION], action_scalar)
|
||||||
|
|
||||||
|
|
||||||
def test_action_non_tensor():
|
def test_action_non_tensor_raises_error():
|
||||||
"""Test that non-tensor actions remain unchanged."""
|
"""Test that non-tensor actions raise ValueError for PolicyAction processors."""
|
||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
# List action
|
# List action should raise error
|
||||||
action_list = [0.1, 0.2, 0.3, 0.4]
|
action_list = [0.1, 0.2, 0.3, 0.4]
|
||||||
transition = create_transition(action=action_list, observation={})
|
transition = create_transition(action=action_list)
|
||||||
result = processor(transition)
|
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||||
assert result[TransitionKey.ACTION] == action_list
|
processor(transition)
|
||||||
|
|
||||||
# Numpy array action (as Python object, not converted)
|
# Numpy array action should raise error
|
||||||
action_numpy = np.array([1, 2, 3, 4])
|
action_numpy = np.array([1, 2, 3, 4])
|
||||||
transition = create_transition(action=action_numpy, observation={})
|
transition = create_transition(action=action_numpy)
|
||||||
result = processor(transition)
|
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||||
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
|
processor(transition)
|
||||||
|
|
||||||
# String action (edge case)
|
# String action should raise error
|
||||||
action_string = "forward"
|
action_string = "forward"
|
||||||
transition = create_transition(action=action_string, observation={})
|
transition = create_transition(action=action_string)
|
||||||
result = processor(transition)
|
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||||
assert result[TransitionKey.ACTION] == action_string
|
processor(transition)
|
||||||
|
|
||||||
# Dict action (structured action)
|
# Dict action should raise error
|
||||||
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
|
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
|
||||||
transition = create_transition(action=action_dict, observation={})
|
transition = create_transition(action=action_dict)
|
||||||
result = processor(transition)
|
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||||
assert result[TransitionKey.ACTION] == action_dict
|
processor(transition)
|
||||||
|
|
||||||
|
|
||||||
def test_action_none():
|
def test_action_none():
|
||||||
"""Test that None action is handled correctly."""
|
"""Test that empty action tensor is handled correctly."""
|
||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
transition = create_transition(action={}, observation={})
|
transition = create_transition(action=torch.empty(0), observation={})
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
assert result[TransitionKey.ACTION] == {}
|
# Empty 1D tensor becomes empty 2D tensor with batch dimension
|
||||||
|
assert result[TransitionKey.ACTION].shape == (1, 0)
|
||||||
|
|
||||||
|
|
||||||
def test_action_with_observation():
|
def test_action_with_observation():
|
||||||
@@ -630,7 +631,9 @@ def test_task_string_to_list():
|
|||||||
|
|
||||||
# Create complementary data with string task
|
# Create complementary data with string task
|
||||||
complementary_data = {"task": "pick_cube"}
|
complementary_data = {"task": "pick_cube"}
|
||||||
transition = create_transition(action={}, observation={}, complementary_data=complementary_data)
|
transition = create_transition(
|
||||||
|
action=torch.empty(0), observation={}, complementary_data=complementary_data
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -647,14 +650,18 @@ def test_task_string_validation():
|
|||||||
|
|
||||||
# Valid string task - should be converted to list
|
# Valid string task - should be converted to list
|
||||||
complementary_data = {"task": "valid_task"}
|
complementary_data = {"task": "valid_task"}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
assert processed_comp_data["task"] == ["valid_task"]
|
assert processed_comp_data["task"] == ["valid_task"]
|
||||||
|
|
||||||
# Valid list of strings - should remain unchanged
|
# Valid list of strings - should remain unchanged
|
||||||
complementary_data = {"task": ["task1", "task2"]}
|
complementary_data = {"task": ["task1", "task2"]}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
assert processed_comp_data["task"] == ["task1", "task2"]
|
assert processed_comp_data["task"] == ["task1", "task2"]
|
||||||
@@ -676,7 +683,9 @@ def test_task_list_of_strings():
|
|||||||
|
|
||||||
for task_list in test_lists:
|
for task_list in test_lists:
|
||||||
complementary_data = {"task": task_list}
|
complementary_data = {"task": task_list}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -690,7 +699,7 @@ def test_complementary_data_none():
|
|||||||
"""Test processor handles None complementary_data gracefully."""
|
"""Test processor handles None complementary_data gracefully."""
|
||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
transition = create_transition(complementary_data=None, action={}, observation={})
|
transition = create_transition(complementary_data=None, action=torch.empty(0), observation={})
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||||
@@ -701,7 +710,9 @@ def test_complementary_data_empty():
|
|||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
complementary_data = {}
|
complementary_data = {}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -717,7 +728,9 @@ def test_complementary_data_no_task():
|
|||||||
"timestamp": 1234567890.0,
|
"timestamp": 1234567890.0,
|
||||||
"extra_info": "some data",
|
"extra_info": "some data",
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -736,7 +749,9 @@ def test_complementary_data_mixed():
|
|||||||
"difficulty": "hard",
|
"difficulty": "hard",
|
||||||
"metadata": {"scene": "kitchen"},
|
"metadata": {"scene": "kitchen"},
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -803,7 +818,9 @@ def test_task_comprehensive_string_cases():
|
|||||||
# Test that all string tasks get properly batched
|
# Test that all string tasks get properly batched
|
||||||
for task in string_tasks:
|
for task in string_tasks:
|
||||||
complementary_data = {"task": task}
|
complementary_data = {"task": task}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -825,7 +842,9 @@ def test_task_comprehensive_string_cases():
|
|||||||
|
|
||||||
for task_list in list_tasks:
|
for task_list in list_tasks:
|
||||||
complementary_data = {"task": task_list}
|
complementary_data = {"task": task_list}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -845,7 +864,9 @@ def test_task_preserves_other_keys():
|
|||||||
"config": {"speed": "slow", "precision": "high"},
|
"config": {"speed": "slow", "precision": "high"},
|
||||||
"metrics": [1.0, 2.0, 3.0],
|
"metrics": [1.0, 2.0, 3.0],
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -869,7 +890,9 @@ def test_index_scalar_to_1d():
|
|||||||
# Create 0D index tensor (scalar)
|
# Create 0D index tensor (scalar)
|
||||||
index_0d = torch.tensor(42, dtype=torch.int64)
|
index_0d = torch.tensor(42, dtype=torch.int64)
|
||||||
complementary_data = {"index": index_0d}
|
complementary_data = {"index": index_0d}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -886,7 +909,9 @@ def test_task_index_scalar_to_1d():
|
|||||||
# Create 0D task_index tensor (scalar)
|
# Create 0D task_index tensor (scalar)
|
||||||
task_index_0d = torch.tensor(7, dtype=torch.int64)
|
task_index_0d = torch.tensor(7, dtype=torch.int64)
|
||||||
complementary_data = {"task_index": task_index_0d}
|
complementary_data = {"task_index": task_index_0d}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -908,7 +933,9 @@ def test_index_and_task_index_together():
|
|||||||
"task_index": task_index_0d,
|
"task_index": task_index_0d,
|
||||||
"task": "pick_object",
|
"task": "pick_object",
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -936,13 +963,17 @@ def test_index_already_batched():
|
|||||||
|
|
||||||
# Test 1D (already batched)
|
# Test 1D (already batched)
|
||||||
complementary_data = {"index": index_1d}
|
complementary_data = {"index": index_1d}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
|
||||||
|
|
||||||
# Test 2D
|
# Test 2D
|
||||||
complementary_data = {"index": index_2d}
|
complementary_data = {"index": index_2d}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
|
||||||
|
|
||||||
@@ -957,13 +988,17 @@ def test_task_index_already_batched():
|
|||||||
|
|
||||||
# Test 1D (already batched)
|
# Test 1D (already batched)
|
||||||
complementary_data = {"task_index": task_index_1d}
|
complementary_data = {"task_index": task_index_1d}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
|
||||||
|
|
||||||
# Test 2D
|
# Test 2D
|
||||||
complementary_data = {"task_index": task_index_2d}
|
complementary_data = {"task_index": task_index_2d}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
|
||||||
|
|
||||||
@@ -976,7 +1011,9 @@ def test_index_non_tensor_unchanged():
|
|||||||
"index": 42, # Plain int, not tensor
|
"index": 42, # Plain int, not tensor
|
||||||
"task_index": [1, 2, 3], # List, not tensor
|
"task_index": [1, 2, 3], # List, not tensor
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -999,7 +1036,9 @@ def test_index_dtype_preservation():
|
|||||||
"index": index_0d,
|
"index": index_0d,
|
||||||
"task_index": task_index_0d,
|
"task_index": task_index_0d,
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -1062,7 +1101,9 @@ def test_index_device_compatibility():
|
|||||||
"index": index_0d,
|
"index": index_0d,
|
||||||
"task_index": task_index_0d,
|
"task_index": task_index_0d,
|
||||||
}
|
}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
@@ -1081,7 +1122,9 @@ def test_empty_index_tensor():
|
|||||||
# Empty 0D tensor doesn't make sense, but test empty 1D
|
# Empty 0D tensor doesn't make sense, but test empty 1D
|
||||||
index_empty = torch.tensor([], dtype=torch.int64)
|
index_empty = torch.tensor([], dtype=torch.int64)
|
||||||
complementary_data = {"index": index_empty}
|
complementary_data = {"index": index_empty}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
result = processor(transition)
|
result = processor(transition)
|
||||||
|
|
||||||
@@ -1116,7 +1159,9 @@ def test_task_processing_creates_new_transition():
|
|||||||
processor = AddBatchDimensionProcessorStep()
|
processor = AddBatchDimensionProcessorStep()
|
||||||
|
|
||||||
complementary_data = {"task": "sort_objects"}
|
complementary_data = {"task": "sort_objects"}
|
||||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
transition = create_transition(
|
||||||
|
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||||
|
)
|
||||||
|
|
||||||
# Store reference to original transition and complementary_data
|
# Store reference to original transition and complementary_data
|
||||||
original_transition = transition
|
original_transition = transition
|
||||||
|
|||||||
@@ -329,14 +329,14 @@ def test_min_max_unnormalization(action_stats_min_max):
|
|||||||
assert torch.allclose(unnormalized_action, expected)
|
assert torch.allclose(unnormalized_action, expected)
|
||||||
|
|
||||||
|
|
||||||
def test_numpy_action_input(action_stats_mean_std):
|
def test_tensor_action_input(action_stats_mean_std):
|
||||||
features = _create_action_features()
|
features = _create_action_features()
|
||||||
norm_map = _create_action_norm_map_mean_std()
|
norm_map = _create_action_norm_map_mean_std()
|
||||||
unnormalizer = UnnormalizerProcessorStep(
|
unnormalizer = UnnormalizerProcessorStep(
|
||||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||||
)
|
)
|
||||||
|
|
||||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32)
|
||||||
transition = create_transition(action=normalized_action)
|
transition = create_transition(action=normalized_action)
|
||||||
|
|
||||||
unnormalized_transition = unnormalizer(transition)
|
unnormalized_transition = unnormalizer(transition)
|
||||||
|
|||||||
@@ -371,12 +371,12 @@ def test_sac_processor_edge_cases():
|
|||||||
assert processed[TransitionKey.OBSERVATION] == {}
|
assert processed[TransitionKey.OBSERVATION] == {}
|
||||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||||
|
|
||||||
# Test with None action
|
# Test with zero action (representing "null" action)
|
||||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action={})
|
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5))
|
||||||
processed = preprocessor(transition)
|
processed = preprocessor(transition)
|
||||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||||
# When action is None, it may still be present with None value
|
# Action should be present and batched, even if it's zeros
|
||||||
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
|
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||||
|
|||||||
Reference in New Issue
Block a user