From 376a6457cf5245751c889fc5b7130ea25a8bbe9f Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Thu, 11 Sep 2025 13:36:04 +0200
Subject: [PATCH] feat(processor): enhance type safety with generic DataProcessorPipeline for policy and robot pipelines (#1915)

* refactor(processor): enhance type annotations for processors in record, replay, teleoperate, and control utils

- Updated type annotations for preprocessor and postprocessor parameters in record_loop and predict_action functions to specify the expected dictionary types.
- Adjusted robot_action_processor type in ReplayConfig and TeleoperateConfig to improve clarity and maintainability.
- Ensured consistency in type definitions across multiple files, enhancing overall code readability.

* refactor(processor): enhance type annotations for RobotProcessorPipeline in various files

- Updated type annotations for RobotProcessorPipeline instances in evaluate.py, record.py, replay.py, teleoperate.py, and other related files to specify input and output types more clearly.
- Introduced new type conversions for PolicyAction and EnvTransition to improve type safety and maintainability across the processing pipelines.
- Ensured consistency in type definitions, enhancing overall code readability and reducing potential runtime errors.

* refactor(processor): update transition handling in processors to use transition_to_batch

- Replaced direct transition handling with transition_to_batch in various processor tests and implementations to ensure consistent batching of input data.
- Updated assertions in tests to reflect changes in data structure, enhancing clarity and maintainability.
- Improved overall code readability by standardizing the way transitions are processed across different processor types.

* refactor(tests): standardize transition key usage in processor tests

- Updated assertions in processor test files to utilize the TransitionKey for action references, enhancing consistency across tests.
- Replaced direct string references with TransitionKey constants for improved readability and maintainability. - Ensured that all relevant tests reflect these changes, contributing to a more uniform approach in handling transitions. --- examples/phone_to_so100/evaluate.py | 7 +- examples/phone_to_so100/record.py | 8 +- examples/phone_to_so100/replay.py | 3 +- examples/phone_to_so100/teleoperate.py | 3 +- src/lerobot/policies/act/processor_act.py | 30 ++- .../policies/diffusion/processor_diffusion.py | 30 ++- src/lerobot/policies/factory.py | 48 ++--- src/lerobot/policies/pi0/processor_pi0.py | 23 ++- .../policies/pi0fast/processor_pi0fast.py | 24 +-- src/lerobot/policies/sac/processor_sac.py | 27 ++- .../sac/reward_model/processor_classifier.py | 21 +- .../policies/smolvla/processor_smolvla.py | 26 ++- src/lerobot/policies/tdmpc/processor_tdmpc.py | 26 ++- src/lerobot/policies/vqbet/processor_vqbet.py | 26 ++- src/lerobot/processor/converters.py | 19 ++ src/lerobot/processor/pipeline.py | 36 ++-- src/lerobot/record.py | 26 +-- src/lerobot/replay.py | 7 +- src/lerobot/scripts/eval.py | 22 +-- src/lerobot/teleoperate.py | 33 ++-- src/lerobot/utils/control_utils.py | 10 +- tests/processor/test_act_processor.py | 119 ++++++----- tests/processor/test_classifier_processor.py | 125 ++++++------ tests/processor/test_diffusion_processor.py | 187 +++++++----------- tests/processor/test_pi0_processor.py | 59 +++--- tests/processor/test_sac_processor.py | 146 ++++++-------- tests/processor/test_smolvla_processor.py | 59 +++--- tests/processor/test_tdmpc_processor.py | 158 +++++++-------- tests/processor/test_vqbet_processor.py | 149 +++++++------- 29 files changed, 671 insertions(+), 786 deletions(-) diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index de8dd9073..bd7272b0b 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -14,6 +14,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. +from typing import Any + from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features @@ -27,6 +29,7 @@ from lerobot.processor.converters import ( observation_to_transition, transition_to_robot_action, ) +from lerobot.processor.core import EnvTransition, RobotAction from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( @@ -66,7 +69,7 @@ kinematics_solver = RobotKinematics( ) # Build pipeline to convert ee pose action to joint action -robot_ee_to_joints_processor = RobotProcessorPipeline( +robot_ee_to_joints_processor = RobotProcessorPipeline[EnvTransition, RobotAction]( steps=[ AddRobotObservationAsComplimentaryData(robot=robot), InverseKinematicsEEToJoints( @@ -80,7 +83,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline( ) # Build pipeline to convert joint observation to ee pose observation -robot_joints_to_ee_pose_processor = RobotProcessorPipeline( +robot_joints_to_ee_pose_processor = RobotProcessorPipeline[dict[str, Any], EnvTransition]( steps=[ ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) ], diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index c47667f4f..f310a12d1 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -27,6 +28,7 @@ from lerobot.processor.converters import ( robot_action_to_transition, transition_to_robot_action, ) +from lerobot.processor.core import EnvTransition, RobotAction from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( @@ -74,7 +76,7 @@ kinematics_solver = RobotKinematics( ) # Build pipeline to convert phone action to ee pose action -phone_to_robot_ee_pose_processor = RobotProcessorPipeline( +phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, EnvTransition]( steps=[ MapPhoneActionToRobotAction(platform=teleop_config.phone_os), AddRobotObservationAsComplimentaryData(robot=robot), @@ -94,7 +96,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline( ) # Build pipeline to convert ee pose action to joint action -robot_ee_to_joints_processor = RobotProcessorPipeline( +robot_ee_to_joints_processor = RobotProcessorPipeline[EnvTransition, RobotAction]( steps=[ InverseKinematicsEEToJoints( kinematics=kinematics_solver, @@ -111,7 +113,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline( ) # Build pipeline to convert joint observation to ee pose observation -robot_joints_to_ee_pose = RobotProcessorPipeline( +robot_joints_to_ee_pose = RobotProcessorPipeline[dict[str, Any], EnvTransition]( steps=[ ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) ], diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 180fdfb3f..4f4dcc62f 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -21,6 +21,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics from 
lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action +from lerobot.processor.core import RobotAction from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( AddRobotObservationAsComplimentaryData, @@ -50,7 +51,7 @@ kinematics_solver = RobotKinematics( ) # Build pipeline to convert ee pose action to joint action -robot_ee_to_joints_processor = RobotProcessorPipeline( +robot_ee_to_joints_processor = RobotProcessorPipeline[RobotAction, RobotAction]( steps=[ AddRobotObservationAsComplimentaryData(robot=robot), InverseKinematicsEEToJoints( diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index f2125544a..b91842d5a 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -18,6 +18,7 @@ import time from lerobot.model.kinematics import RobotKinematics from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action +from lerobot.processor.core import RobotAction from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( AddRobotObservationAsComplimentaryData, @@ -49,7 +50,7 @@ kinematics_solver = RobotKinematics( ) # Build pipeline to convert phone action to ee pose action to joint action -phone_to_robot_joints_processor = RobotProcessorPipeline( +phone_to_robot_joints_processor = RobotProcessorPipeline[RobotAction, RobotAction]( steps=[ MapPhoneActionToRobotAction(platform=teleop_config.phone_os), AddRobotObservationAsComplimentaryData(robot=robot), diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index 76f362464..129d2997f 100644 --- 
a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -22,18 +24,20 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_act_pre_post_processors( config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """Creates the pre- and post-processing pipelines for the ACT policy. The pre-processing pipeline handles normalization, batching, and device placement for the model inputs. @@ -43,19 +47,11 @@ def make_act_pre_post_processors( config (ACTConfig): The ACT policy configuration object. dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset statistics (e.g., mean and std) used for normalization. Defaults to None. - preprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the - preprocessor pipeline's constructor. Defaults to None. - postprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the - postprocessor pipeline's constructor. Defaults to None. 
Returns: - tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: A tuple containing the + tuple[PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction]]: A tuple containing the pre-processor pipeline and the post-processor pipeline. """ - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ RenameObservationsProcessorStep(rename_map={}), @@ -76,14 +72,14 @@ def make_act_pre_post_processors( ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index 324ffea42..e9ada252b 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -14,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -23,23 +25,25 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_diffusion_pre_post_processors( config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for a diffusion policy. The pre-processing pipeline prepares the input data for the model by: - 1. Renaming features (if a `rename_map` is provided in `preprocessor_kwargs`). + 1. Renaming features. 2. Normalizing the input and output features based on dataset statistics. 3. Adding a batch dimension. 4. Moving the data to the specified device. @@ -53,18 +57,10 @@ def make_diffusion_pre_post_processors( containing feature definitions, normalization mappings, and device information. dataset_stats: A dictionary of statistics used for normalization. Defaults to None. - preprocessor_kwargs: Additional keyword arguments - for the pre-processor pipeline. Defaults to an empty dictionary. - postprocessor_kwargs: Additional keyword arguments - for the post-processor pipeline. Defaults to an empty dictionary. Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ RenameObservationsProcessorStep(rename_map={}), @@ -83,14 +79,14 @@ def make_diffusion_pre_post_processors( ), ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 520f2342d..8d94f5837 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -39,7 +39,14 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs +from lerobot.processor import PolicyProcessorPipeline +from lerobot.processor.converters import ( + batch_to_transition, + policy_action_to_transition, + transition_to_batch, + transition_to_policy_action, +) +from lerobot.processor.core import PolicyAction def get_policy_class(name: str) -> type[PreTrainedPolicy]: @@ -153,8 +160,6 @@ class ProcessorConfigKwargs(TypedDict, total=False): preprocessor_overrides: A dictionary of overrides for the preprocessor configuration. postprocessor_overrides: A dictionary of overrides for the postprocessor configuration. dataset_stats: Dataset statistics for normalization. - preprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`. 
- postprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`. """ preprocessor_config_filename: str | None @@ -162,15 +167,16 @@ class ProcessorConfigKwargs(TypedDict, total=False): preprocessor_overrides: dict[str, Any] | None postprocessor_overrides: dict[str, Any] | None dataset_stats: dict[str, dict[str, torch.Tensor]] | None - preprocessor_kwargs: ProcessorKwargs | None - postprocessor_kwargs: ProcessorKwargs | None def make_pre_post_processors( policy_cfg: PreTrainedConfig, pretrained_path: str | None = None, **kwargs: Unpack[ProcessorConfigKwargs], -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Create or load pre- and post-processor pipelines for a given policy. @@ -194,10 +200,6 @@ def make_pre_post_processors( policy configuration type. """ if pretrained_path: - # Extract preprocessor and postprocessor kwargs - preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {}) - postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {}) - return ( PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, @@ -205,8 +207,8 @@ def make_pre_post_processors( "preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" ), overrides=kwargs.get("preprocessor_overrides", {}), - to_transition=preprocessor_kwargs.get("to_transition"), - to_output=preprocessor_kwargs.get("to_output"), + to_transition=batch_to_transition, + to_output=transition_to_batch, ), PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, @@ -214,8 +216,8 @@ def make_pre_post_processors( "postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" ), overrides=kwargs.get("postprocessor_overrides", {}), - to_transition=postprocessor_kwargs.get("to_transition"), - to_output=postprocessor_kwargs.get("to_output"), + 
to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) @@ -226,8 +228,6 @@ def make_pre_post_processors( processors = make_tdmpc_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, DiffusionConfig): @@ -236,8 +236,6 @@ def make_pre_post_processors( processors = make_diffusion_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, ACTConfig): @@ -246,8 +244,6 @@ def make_pre_post_processors( processors = make_act_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, VQBeTConfig): @@ -256,8 +252,6 @@ def make_pre_post_processors( processors = make_vqbet_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, PI0Config): @@ -266,8 +260,6 @@ def make_pre_post_processors( processors = make_pi0_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, PI0FASTConfig): @@ -276,8 +268,6 @@ def make_pre_post_processors( processors = make_pi0fast_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, SACConfig): @@ 
-286,8 +276,6 @@ def make_pre_post_processors( processors = make_sac_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, RewardClassifierConfig): @@ -296,8 +284,6 @@ def make_pre_post_processors( processors = make_classifier_processor( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) elif isinstance(policy_cfg, SmolVLAConfig): @@ -306,8 +292,6 @@ def make_pre_post_processors( processors = make_smolvla_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), - preprocessor_kwargs=kwargs.get("preprocessor_kwargs"), - postprocessor_kwargs=kwargs.get("postprocessor_kwargs"), ) else: diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 417105208..f6470d8cf 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any import torch @@ -26,13 +27,14 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, ProcessorStep, ProcessorStepRegistry, RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction @ProcessorStepRegistry.register(name="pi0_new_line_processor") @@ -95,9 +97,10 @@ class Pi0NewLineProcessor(ComplementaryDataProcessorStep): def make_pi0_pre_post_processors( config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the PI0 policy. @@ -122,10 +125,6 @@ def make_pi0_pre_post_processors( Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} # Add remaining processors input_steps: list[ProcessorStep] = [ @@ -154,14 +153,14 @@ def make_pi0_pre_post_processors( ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index e6ed9a4d2..de4443413 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -23,18 +25,20 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_pi0fast_pre_post_processors( config: PI0FASTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the PI0Fast policy. @@ -57,10 +61,6 @@ def make_pi0fast_pre_post_processors( Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one @@ -79,14 +79,14 @@ def make_pi0fast_pre_post_processors( ), ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 0caeec32a..c0cd8f751 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -24,18 +26,20 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_sac_pre_post_processors( config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the SAC policy. @@ -52,17 +56,12 @@ def make_sac_pre_post_processors( Args: config: The configuration object for the SAC policy. dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} + # Add remaining processors input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), @@ -80,14 +79,14 @@ def make_sac_pre_post_processors( ), ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py index 571ccdfd9..a0c6bd503 100644 --- a/src/lerobot/policies/sac/reward_model/processor_classifier.py +++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py @@ -13,6 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from typing import Any + import torch from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -21,16 +24,18 @@ from lerobot.processor import ( IdentityProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_classifier_processor( config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the reward classifier. @@ -51,10 +56,6 @@ def make_classifier_processor( Returns: A tuple containing the configured pre-processor and post-processor pipelines. """ - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ NormalizerProcessorStep( @@ -71,11 +72,11 @@ def make_classifier_processor( PolicyProcessorPipeline( steps=input_steps, name="classifier_preprocessor", - **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, name="classifier_postprocessor", - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 92002ebad..90ac0fa9a 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature @@ -25,20 +27,22 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, ProcessorStepRegistry, RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_smolvla_pre_post_processors( config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the SmolVLA policy. @@ -57,16 +61,10 @@ def make_smolvla_pre_post_processors( Args: config: The configuration object for the SmolVLA policy. dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one @@ -92,15 +90,15 @@ def make_smolvla_pre_post_processors( ), ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index 76e7b7ab1..51f97bd6e 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -14,6 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -23,18 +25,20 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_tdmpc_pre_post_processors( config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the TDMPC policy. @@ -51,16 +55,10 @@ def make_tdmpc_pre_post_processors( Args: config: The configuration object for the TDMPC policy. dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ RenameObservationsProcessorStep(rename_map={}), @@ -79,14 +77,14 @@ def make_tdmpc_pre_post_processors( ), ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index 41a9d66f8..bd9f7102a 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -15,6 +15,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any + import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -24,18 +26,20 @@ from lerobot.processor import ( DeviceProcessorStep, NormalizerProcessorStep, PolicyProcessorPipeline, - ProcessorKwargs, RenameObservationsProcessorStep, UnnormalizerProcessorStep, ) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import PolicyAction def make_vqbet_pre_post_processors( config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, - preprocessor_kwargs: ProcessorKwargs | None = None, - postprocessor_kwargs: ProcessorKwargs | None = None, -) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: """ Constructs pre-processor and post-processor pipelines for the VQ-BeT policy. @@ -52,16 +56,10 @@ def make_vqbet_pre_post_processors( Args: config: The configuration object for the VQ-BeT policy. dataset_stats: A dictionary of statistics for normalization. - preprocessor_kwargs: Additional arguments for the pre-processor pipeline. - postprocessor_kwargs: Additional arguments for the post-processor pipeline. Returns: A tuple containing the configured pre-processor and post-processor pipelines. 
""" - if preprocessor_kwargs is None: - preprocessor_kwargs = {} - if postprocessor_kwargs is None: - postprocessor_kwargs = {} input_steps = [ RenameObservationsProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys @@ -80,14 +78,14 @@ def make_vqbet_pre_post_processors( ), ] return ( - PolicyProcessorPipeline( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, name=POLICY_PREPROCESSOR_DEFAULT_NAME, - **preprocessor_kwargs, ), - PolicyProcessorPipeline( + PolicyProcessorPipeline[PolicyAction, PolicyAction]( steps=output_steps, name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - **postprocessor_kwargs, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, ), ) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 8456cad11..3ae846056 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -331,6 +331,25 @@ def transition_to_robot_action(transition: EnvTransition) -> RobotAction: return transition.get(TransitionKey.ACTION) +def transition_to_policy_action(transition: EnvTransition) -> PolicyAction: + """ + Convert an `EnvTransition` to a `PolicyAction`. + """ + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + return action + + +def policy_action_to_transition(action: PolicyAction) -> EnvTransition: + """ + Convert a `PolicyAction` to an `EnvTransition`. + """ + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + return create_transition(action=action) + + def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition: """ Merge a sequence of transitions into a single one. 
diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index c3440ff36..644fca180 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -34,7 +34,8 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from .converters import batch_to_transition, create_transition, transition_to_batch from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey -# Type variable for generic processor output type +# Type variables for generic processor input and output types +TInput = TypeVar("TInput") TOutput = TypeVar("TOutput") @@ -180,10 +181,13 @@ class ProcessorKwargs(TypedDict, total=False): to_transition: Callable[[dict[str, Any]], EnvTransition] | None to_output: Callable[[EnvTransition], Any] | None + name: str | None + before_step_hooks: list[Callable[[int, EnvTransition], None]] | None + after_step_hooks: list[Callable[[int, EnvTransition], None]] | None @dataclass -class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): +class DataProcessorPipeline(ModelHubMixin, Generic[TInput, TOutput]): """ Composable, debuggable post-processing processor for robot transitions. 
@@ -217,14 +221,14 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict # For EnvTransition output, explicitly specify identity function - transition_processor: DataProcessorPipeline[EnvTransition] = DataProcessorPipeline( + transition_processor: DataProcessorPipeline[EnvTransition, EnvTransition] = DataProcessorPipeline( steps=[some_step1, some_step2], to_output=lambda x: x, # Identity function ) result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition # For custom output types - processor: DataProcessorPipeline[str] = DataProcessorPipeline( + processor: DataProcessorPipeline[dict[str, Any], str] = DataProcessorPipeline( steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys" ) result: str = processor(batch_data) # Type checker knows this is str @@ -248,7 +252,9 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): steps: Sequence[ProcessorStep] = field(default_factory=list) name: str = "DataProcessorPipeline" - to_transition: Callable[[dict[str, Any]], EnvTransition] = field(default=batch_to_transition, repr=False) + to_transition: Callable[[TInput], EnvTransition] = field( + default_factory=lambda: cast(Callable[[TInput], EnvTransition], batch_to_transition), repr=False + ) to_output: Callable[[EnvTransition], TOutput] = field( # Cast is necessary here: Working around Python type-checker limitation. # _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput @@ -263,7 +269,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) - def __call__(self, data: dict[str, Any]) -> TOutput: + def __call__(self, data: TInput) -> TOutput: """Process data through all steps. 
The method accepts a batch dictionary (like the ones returned by ReplayBuffer or @@ -299,7 +305,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): hook(idx, transition) return transition - def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]: + def step_through(self, data: TInput) -> Iterable[EnvTransition]: """Yield the intermediate results after each processor step. This is a low-level method that does NOT apply hooks. It simply executes each step @@ -419,10 +425,10 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): revision: str | None = None, config_filename: str | None = None, overrides: dict[str, Any] | None = None, - to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None, + to_transition: Callable[[TInput], EnvTransition] | None = None, to_output: Callable[[EnvTransition], TOutput] | None = None, **kwargs, - ) -> DataProcessorPipeline[TOutput]: + ) -> DataProcessorPipeline[TInput, TOutput]: """Load a serialized processor from source (local path or Hugging Face Hub identifier). Args: @@ -443,7 +449,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): Use identity function (lambda x: x) for EnvTransition output. Returns: - A DataProcessorPipeline[TOutput] instance loaded from the saved configuration. + A DataProcessorPipeline[TInput, TOutput] instance loaded from the saved configuration. Raises: ImportError: If a processor step class cannot be loaded or imported. @@ -652,7 +658,7 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): """Return the number of steps in the processor.""" return len(self.steps) - def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TOutput]: + def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TInput, TOutput]: """Indexing helper exposing underlying steps. * ``int`` – returns the idx-th ProcessorStep. * ``slice`` – returns a new DataProcessorPipeline with the sliced steps. 
@@ -755,7 +761,9 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.OBSERVATION] - def process_action(self, action: Any | torch.Tensor) -> Any | torch.Tensor: + def process_action( + self, action: PolicyAction | RobotAction | EnvAction + ) -> PolicyAction | RobotAction | EnvAction: transition: EnvTransition = create_transition(action=action) transformed_transition = self._forward(transition) return transformed_transition[TransitionKey.ACTION] @@ -786,8 +794,8 @@ class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]): return transformed_transition[TransitionKey.COMPLEMENTARY_DATA] -RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TOutput] -PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TOutput] +RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] +PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] class ObservationProcessorStep(ProcessorStep, ABC): diff --git a/src/lerobot/record.py b/src/lerobot/record.py index e0f40f6b4..e82afc21c 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -91,6 +91,7 @@ from lerobot.processor.converters import ( transition_to_dataset_frame, transition_to_robot_action, ) +from lerobot.processor.core import PolicyAction, RobotAction from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, @@ -243,34 +244,37 @@ def record_loop( dataset: LeRobotDataset | None = None, teleop: Teleoperator | list[Teleoperator] | None = None, policy: PreTrainedPolicy | None = None, - preprocessor: PolicyProcessorPipeline | None = None, - postprocessor: PolicyProcessorPipeline | None = None, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None, + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None, control_time_s: int | None = None, - 
teleop_action_processor: RobotProcessorPipeline[EnvTransition] | None = None, # runs after teleop - robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None, # runs before robot - robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None, # runs after robot + teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition] + | None = None, # runs after teleop + robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction] + | None = None, # runs before robot + robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition] + | None = None, # runs after robot single_task: str | None = None, display_data: bool = False, ): - teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( + teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition] = ( teleop_action_processor - or RobotProcessorPipeline( + or RobotProcessorPipeline[RobotAction, EnvTransition]( steps=[IdentityProcessorStep()], to_transition=robot_action_to_transition, to_output=identity_transition, ) ) - robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = ( + robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction] = ( robot_action_processor - or RobotProcessorPipeline( + or RobotProcessorPipeline[EnvTransition, RobotAction]( steps=[IdentityProcessorStep()], to_transition=identity_transition, to_output=transition_to_robot_action, ) ) - robot_observation_processor: RobotProcessorPipeline[EnvTransition] = ( + robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition] = ( robot_observation_processor - or RobotProcessorPipeline( + or RobotProcessorPipeline[dict[str, Any], EnvTransition]( steps=[IdentityProcessorStep()], to_transition=observation_to_transition, to_output=identity_transition, diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index c641a22d1..3c184703e 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -49,6 
+49,7 @@ from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action +from lerobot.processor.core import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -86,7 +87,7 @@ class ReplayConfig: # Use vocal synthesis to read events. play_sounds: bool = True # Optional processor for actions before sending to robot - robot_action_processor: RobotProcessorPipeline | None = None + robot_action_processor: RobotProcessorPipeline[RobotAction, RobotAction] | None = None @parser.wrap() @@ -95,10 +96,10 @@ def replay(cfg: ReplayConfig): logging.info(pformat(asdict(cfg))) # Initialize robot action processor with default if not provided - robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline( + robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline[RobotAction, RobotAction]( steps=[IdentityProcessorStep()], to_transition=robot_action_to_transition, - to_output=transition_to_robot_action, # type: ignore[arg-type] + to_output=transition_to_robot_action, ) # Reset processor diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 38501a2e2..7a0bafc68 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -72,7 +72,7 @@ from lerobot.envs.factory import make_env from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor.core import TransitionKey +from lerobot.processor.core import PolicyAction from lerobot.processor.pipeline import PolicyProcessorPipeline from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed @@ -86,8 +86,8 @@ from 
lerobot.utils.utils import ( def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, - preprocessor: PolicyProcessorPipeline[dict[str, Any]], - postprocessor: PolicyProcessorPipeline[dict[str, Any]], + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], seeds: list[int] | None = None, return_observations: bool = False, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, @@ -159,15 +159,15 @@ def rollout( observation = add_envs_task(env, observation) observation = preprocessor(observation) with torch.inference_mode(): - action = policy.select_action(observation) - action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION] + action: PolicyAction = policy.select_action(observation) + action: PolicyAction = postprocessor(action) # Convert to CPU / numpy. - action: np.ndarray = action.to("cpu").numpy() - assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" + action_numpy: np.ndarray = action.to("cpu").numpy() + assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" # Apply the next action. - observation, reward, terminated, truncated, info = env.step(action) + observation, reward, terminated, truncated, info = env.step(action_numpy) if render_callback is not None: render_callback(env) @@ -181,7 +181,7 @@ def rollout( # Keep track of which environments are done so far. 
done = terminated | truncated | done - all_actions.append(torch.from_numpy(action)) + all_actions.append(torch.from_numpy(action_numpy)) all_rewards.append(torch.from_numpy(reward)) all_dones.append(torch.from_numpy(done)) all_successes.append(torch.tensor(successes)) @@ -220,8 +220,8 @@ def rollout( def eval_policy( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, - preprocessor: PolicyProcessorPipeline[dict[str, Any]], - postprocessor: PolicyProcessorPipeline[dict[str, Any]], + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, max_episodes_rendered: int = 0, videos_dir: Path | None = None, diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index a24c9fb0c..e2a5e10f0 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -69,6 +69,7 @@ from lerobot.processor.converters import ( robot_action_to_transition, transition_to_robot_action, ) +from lerobot.processor.core import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -106,9 +107,15 @@ class TeleoperateConfig: # Display all cameras on screen display_data: bool = False # Optional processors for data transformation - teleop_action_processor: RobotProcessorPipeline | None = None # runs after teleop - robot_action_processor: RobotProcessorPipeline | None = None # runs before robot - robot_observation_processor: RobotProcessorPipeline | None = None # runs after robot + teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition] | None = ( + None # runs after teleop + ) + robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction] | None = ( + None # runs before robot + ) + robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition] | None = ( + None # runs after robot + ) def teleop_loop( @@ -117,9 +124,9 @@ def teleop_loop( fps: int, display_data: bool = False, duration: float | None = 
None, - teleop_action_processor: RobotProcessorPipeline[EnvTransition] | None = None, - robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None, - robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None, + teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition] | None = None, + robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction] | None = None, + robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition] | None = None, ): """ This function continuously reads actions from a teleoperation device, processes them through optional @@ -137,25 +144,25 @@ def teleop_loop( robot_observation_processor: An optional pipeline to process raw observations from the robot. """ # Initialize processors with defaults if not provided - teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( + teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition] = ( teleop_action_processor - or RobotProcessorPipeline( + or RobotProcessorPipeline[RobotAction, EnvTransition]( steps=[IdentityProcessorStep()], to_transition=robot_action_to_transition, to_output=identity_transition, ) ) - robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = ( + robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction] = ( robot_action_processor - or RobotProcessorPipeline( + or RobotProcessorPipeline[EnvTransition, RobotAction]( steps=[IdentityProcessorStep()], to_transition=identity_transition, - to_output=transition_to_robot_action, # type: ignore[arg-type] + to_output=transition_to_robot_action, ) ) - robot_observation_processor: RobotProcessorPipeline[EnvTransition] = ( + robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition] = ( robot_observation_processor - or RobotProcessorPipeline( + or RobotProcessorPipeline[dict[str, Any], EnvTransition]( steps=[IdentityProcessorStep()], to_transition=observation_to_transition, 
to_output=identity_transition, diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 087f35732..febe7070e 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -22,6 +22,7 @@ import traceback from contextlib import nullcontext from copy import copy from functools import cache +from typing import Any import numpy as np import torch @@ -31,7 +32,8 @@ from termcolor import colored from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import PolicyProcessorPipeline, TransitionKey +from lerobot.processor import PolicyProcessorPipeline +from lerobot.processor.core import PolicyAction from lerobot.robots import Robot @@ -125,8 +127,8 @@ def predict_action( observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, - preprocessor: PolicyProcessorPipeline, - postprocessor: PolicyProcessorPipeline, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], use_amp: bool, task: str | None = None, robot_type: str | None = None, @@ -177,7 +179,7 @@ def predict_action( # based on the current observation action = policy.select_action(observation) - action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION] + action = postprocessor(action) # Remove batch dimension action = action.squeeze(0) diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index 8c663a4c1..548281703 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -33,7 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, 
transition_to_batch def create_default_config(): @@ -93,28 +93,26 @@ def test_act_processor_normalization(): preprocessor, postprocessor = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is normalized and batched - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) - assert processed[TransitionKey.ACTION].shape == (1, 4) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[TransitionKey.ACTION.value].shape == (1, 4) # Process action through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is unnormalized - assert postprocessed[TransitionKey.ACTION].shape == (1, 4) + assert postprocessed.shape == (1, 4) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -127,28 +125,26 @@ def test_act_processor_cuda(): preprocessor, postprocessor = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = 
preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is back on CPU - assert postprocessed[TransitionKey.ACTION].device.type == "cpu" + assert postprocessed.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -161,8 +157,6 @@ def test_act_processor_accelerate_scenario(): preprocessor, postprocessor = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU @@ -170,13 +164,14 @@ def test_act_processor_accelerate_scenario(): observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU action = torch.randn(1, 4).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data stays on same GPU (not moved unnecessarily) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -189,7 +184,6 @@ def test_act_processor_multi_gpu(): preprocessor, 
postprocessor = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate data on different GPU (like in multi-GPU training) @@ -197,13 +191,14 @@ def test_act_processor_multi_gpu(): observation = {OBS_STATE: torch.randn(1, 7).to(device)} action = torch.randn(1, 4).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data stays on cuda:1 (not moved to cuda:0) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device def test_act_processor_without_stats(): @@ -213,8 +208,6 @@ def test_act_processor_without_stats(): preprocessor, postprocessor = make_act_pre_post_processors( config, dataset_stats=None, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Should still create processors, but normalization won't have stats @@ -225,8 +218,9 @@ def test_act_processor_without_stats(): observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = preprocessor(transition) + processed = preprocessor(batch) assert processed is not None @@ -238,8 +232,6 @@ def test_act_processor_save_and_load(): preprocessor, postprocessor = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) with 
tempfile.TemporaryDirectory() as tmpdir: @@ -247,18 +239,17 @@ def test_act_processor_save_and_load(): preprocessor.save_pretrained(tmpdir) # Load preprocessor - loaded_preprocessor = DataProcessorPipeline.from_pretrained( - tmpdir, to_transition=identity_transition, to_output=identity_transition - ) + loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir) # Test that loaded processor works observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = loaded_preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) - assert processed[TransitionKey.ACTION].shape == (1, 4) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[TransitionKey.ACTION.value].shape == (1, 4) def test_act_processor_device_placement_preservation(): @@ -271,18 +262,17 @@ def test_act_processor_device_placement_preservation(): preprocessor, _ = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Process CPU data observation = {OBS_STATE: torch.randn(7)} action = torch.randn(4) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" - assert processed[TransitionKey.ACTION].device.type == "cpu" + processed = preprocessor(batch) + assert processed[OBS_STATE].device.type == "cpu" + assert processed[TransitionKey.ACTION.value].device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -296,8 +286,6 @@ def test_act_processor_mixed_precision(): preprocessor, postprocessor = make_act_pre_post_processors( 
config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Replace DeviceProcessorStep with one that uses float16 @@ -307,11 +295,12 @@ def test_act_processor_mixed_precision(): modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) elif isinstance(step, NormalizerProcessorStep): # Update normalizer to use the same device as the device processor + norm_step = step # Now type checker knows this is NormalizerProcessorStep modified_steps.append( NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, device=config.device, dtype=torch.float16, # Match the float16 dtype ) @@ -324,13 +313,14 @@ def test_act_processor_mixed_precision(): observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} action = torch.randn(4, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is converted to float16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 - assert processed[TransitionKey.ACTION].dtype == torch.float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 def test_act_processor_batch_consistency(): @@ -341,26 +331,26 @@ def test_act_processor_batch_consistency(): preprocessor, postprocessor = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Test single sample (unbatched) observation = 
{OBS_STATE: torch.randn(7)} action = torch.randn(4) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched + processed = preprocessor(batch) + assert processed["observation.state"].shape[0] == 1 # Batched # Test already batched data observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 action_batched = torch.randn(8, 4) transition_batched = create_transition(observation_batched, action_batched) + batch_batched = transition_to_batch(transition_batched) - processed_batched = preprocessor(transition_batched) - assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8 - assert processed_batched[TransitionKey.ACTION].shape[0] == 8 + processed_batched = preprocessor(batch_batched) + assert processed_batched[OBS_STATE].shape[0] == 8 + assert processed_batched[TransitionKey.ACTION.value].shape[0] == 8 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -373,7 +363,6 @@ def test_act_processor_bfloat16_device_float32_normalizer(): preprocessor, _ = make_act_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Modify the pipeline to use bfloat16 device processor with float32 normalizer @@ -384,11 +373,12 @@ def test_act_processor_bfloat16_device_float32_normalizer(): modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) elif isinstance(step, NormalizerProcessorStep): # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep modified_steps.append( NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, device=config.device, 
dtype=torch.float32, # Deliberately configured as float32 ) @@ -405,13 +395,14 @@ def test_act_processor_bfloat16_device_float32_normalizer(): observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32 action = torch.randn(4, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16 diff --git a/tests/processor/test_classifier_processor.py b/tests/processor/test_classifier_processor.py index 75c65a4dc..c12844793 100644 --- a/tests/processor/test_classifier_processor.py +++ b/tests/processor/test_classifier_processor.py @@ -31,7 +31,7 @@ from lerobot.processor import ( NormalizerProcessorStep, TransitionKey, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch def create_default_config(): @@ -93,8 +93,6 @@ def test_classifier_processor_normalization(): preprocessor, postprocessor = make_classifier_processor( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data @@ -104,14 +102,15 @@ def test_classifier_processor_normalization(): } action = torch.randn(1) # Dummy action/reward transition = create_transition(observation, action) + batch = 
transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is processed - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (1,) + assert processed[OBS_STATE].shape == (10,) + assert processed[OBS_IMAGE].shape == (3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1,) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -124,8 +123,6 @@ def test_classifier_processor_cuda(): preprocessor, postprocessor = make_classifier_processor( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data @@ -136,20 +133,22 @@ def test_classifier_processor_cuda(): action = torch.randn(1) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" # Process through postprocessor - reward_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(reward_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that output is back on CPU - assert postprocessed[TransitionKey.ACTION].device.type == "cpu" 
+ assert postprocessed.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -162,8 +161,6 @@ def test_classifier_processor_accelerate_scenario(): preprocessor, postprocessor = make_classifier_processor( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU @@ -175,13 +172,16 @@ def test_classifier_processor_accelerate_scenario(): action = torch.randn(1).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -202,13 +202,16 @@ def test_classifier_processor_multi_gpu(): action = torch.randn(1).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on cuda:1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == 
device def test_classifier_processor_without_stats(): @@ -229,7 +232,9 @@ def test_classifier_processor_without_stats(): action = torch.randn(1) transition = create_transition(observation, action) - processed = preprocessor(transition) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) assert processed is not None @@ -238,22 +243,14 @@ def test_classifier_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - factory_preprocessor.steps, to_transition=identity_transition, to_output=identity_transition - ) + preprocessor, postprocessor = make_classifier_processor(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor preprocessor.save_pretrained(tmpdir) # Load preprocessor - loaded_preprocessor = DataProcessorPipeline.from_pretrained( - tmpdir, to_transition=identity_transition, to_output=identity_transition - ) + loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir) # Test that loaded processor works observation = { @@ -262,11 +259,12 @@ def test_classifier_processor_save_and_load(): } action = torch.randn(1) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = loaded_preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (1,) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (10,) + assert processed[OBS_IMAGE].shape == (3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1,) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -276,21 
+274,16 @@ def test_classifier_processor_mixed_precision(): config.device = "cuda" stats = create_default_stats() - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats) + preprocessor, postprocessor = make_classifier_processor(config, stats) # Replace DeviceProcessorStep with one that uses float16 modified_steps = [] - for step in factory_preprocessor.steps: + for step in preprocessor.steps: if isinstance(step, DeviceProcessorStep): modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) else: modified_steps.append(step) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - modified_steps, to_transition=identity_transition, to_output=identity_transition - ) + preprocessor.steps = modified_steps # Create test data observation = { @@ -300,13 +293,16 @@ def test_classifier_processor_mixed_precision(): action = torch.randn(1, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is converted to float16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16 - assert processed[TransitionKey.ACTION].dtype == torch.float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 def test_classifier_processor_batch_data(): @@ -317,8 +313,6 @@ def test_classifier_processor_batch_data(): preprocessor, postprocessor = make_classifier_processor( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, 
"to_output": identity_transition}, ) # Test with batched data @@ -330,13 +324,16 @@ def test_classifier_processor_batch_data(): action = torch.randn(batch_size, 1) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that batch dimension is preserved - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 10) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (batch_size, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (batch_size, 1) + assert processed[OBS_STATE].shape == (batch_size, 10) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 1) def test_classifier_processor_postprocessor_identity(): @@ -347,17 +344,17 @@ def test_classifier_processor_postprocessor_identity(): preprocessor, postprocessor = make_classifier_processor( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data for postprocessor reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions transition = create_transition(action=reward) + _ = transition_to_batch(transition) + # Process through postprocessor - processed = postprocessor(transition) + processed = postprocessor(reward) # IdentityProcessor should leave values unchanged (except device) - assert torch.allclose(processed[TransitionKey.ACTION].cpu(), reward.cpu()) - assert processed[TransitionKey.ACTION].device.type == "cpu" + assert torch.allclose(processed.cpu(), reward.cpu()) + assert processed.device.type == "cpu" diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py index 98215dc18..79fdb9673 100644 --- 
a/tests/processor/test_diffusion_processor.py +++ b/tests/processor/test_diffusion_processor.py @@ -33,7 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch def create_default_config(): @@ -96,8 +96,6 @@ def test_diffusion_processor_with_images(): preprocessor, postprocessor = make_diffusion_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data with images @@ -108,13 +106,16 @@ def test_diffusion_processor_with_images(): action = torch.randn(6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is batched - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (1, 6) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -127,8 +128,6 @@ def test_diffusion_processor_cuda(): preprocessor, postprocessor = make_diffusion_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data @@ -139,20 +138,22 @@ def test_diffusion_processor_cuda(): action = torch.randn(6) transition = 
create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is back on CPU - assert postprocessed[TransitionKey.ACTION].device.type == "cpu" + assert postprocessed.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -165,8 +166,6 @@ def test_diffusion_processor_accelerate_scenario(): preprocessor, postprocessor = make_diffusion_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU @@ -178,13 +177,16 @@ def test_diffusion_processor_accelerate_scenario(): action = torch.randn(1, 6).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert 
processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -205,13 +207,16 @@ def test_diffusion_processor_multi_gpu(): action = torch.randn(1, 6).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on cuda:1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device def test_diffusion_processor_without_stats(): @@ -221,7 +226,6 @@ def test_diffusion_processor_without_stats(): preprocessor, postprocessor = make_diffusion_pre_post_processors( config, dataset_stats=None, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Should still create processors @@ -236,7 +240,9 @@ def test_diffusion_processor_without_stats(): action = torch.randn(6) transition = create_transition(observation, action) - processed = preprocessor(transition) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) assert processed is not None @@ -245,22 +251,14 @@ def test_diffusion_processor_save_and_load(): config = create_default_config() stats = create_default_stats() - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - 
factory_preprocessor.steps, to_transition=identity_transition, to_output=identity_transition - ) + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) with tempfile.TemporaryDirectory() as tmpdir: # Save preprocessor preprocessor.save_pretrained(tmpdir) # Load preprocessor - loaded_preprocessor = DataProcessorPipeline.from_pretrained( - tmpdir, to_transition=identity_transition, to_output=identity_transition - ) + loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir) # Test that loaded processor works observation = { @@ -269,62 +267,12 @@ def test_diffusion_processor_save_and_load(): } action = torch.randn(6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = loaded_preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (1, 6) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_diffusion_processor_mixed_precision(): - """Test Diffusion processor with mixed precision.""" - config = create_default_config() - config.device = "cuda" - stats = create_default_stats() - - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats) - - # Replace DeviceProcessorStep with one that uses float16 - modified_steps = [] - for step in factory_preprocessor.steps: - if isinstance(step, DeviceProcessorStep): - modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) - elif isinstance(step, NormalizerProcessorStep): - # Update normalizer to use the same device as the device processor - modified_steps.append( - NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, - device=config.device, - dtype=torch.float16, # Match the 
float16 dtype - ) - ) - else: - modified_steps.append(step) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - modified_steps, to_transition=identity_transition, to_output=identity_transition - ) - - # Create test data - observation = { - OBS_STATE: torch.randn(7, dtype=torch.float32), - OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), - } - action = torch.randn(6, dtype=torch.float32) - transition = create_transition(observation, action) - - # Process through preprocessor - processed = preprocessor(transition) - - # Check that data is converted to float16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16 - assert processed[TransitionKey.ACTION].dtype == torch.float16 + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) def test_diffusion_processor_identity_normalization(): @@ -335,8 +283,6 @@ def test_diffusion_processor_identity_normalization(): preprocessor, postprocessor = make_diffusion_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data @@ -348,12 +294,15 @@ def test_diffusion_processor_identity_normalization(): action = torch.randn(6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Image should not be normalized (IDENTITY mode) # Just batched - assert torch.allclose(processed[TransitionKey.OBSERVATION][OBS_IMAGE][0], image_value, rtol=1e-5) + assert torch.allclose(processed[OBS_IMAGE][0], 
image_value, rtol=1e-5) def test_diffusion_processor_batch_consistency(): @@ -364,8 +313,6 @@ def test_diffusion_processor_batch_consistency(): preprocessor, postprocessor = make_diffusion_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Test with different batch sizes @@ -377,13 +324,15 @@ def test_diffusion_processor_batch_consistency(): action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6) transition = create_transition(observation, action) - processed = preprocessor(transition) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) # Check correct batch size expected_batch = batch_size if batch_size > 1 else 1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch - assert processed[TransitionKey.ACTION].shape[0] == expected_batch + assert processed[OBS_STATE].shape[0] == expected_batch + assert processed[OBS_IMAGE].shape[0] == expected_batch + assert processed[TransitionKey.ACTION.value].shape[0] == expected_batch @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -393,36 +342,32 @@ def test_diffusion_processor_bfloat16_device_float32_normalizer(): config.device = "cuda" stats = create_default_stats() - # Get the steps from the factory function - factory_preprocessor, _ = make_diffusion_pre_post_processors(config, stats) + preprocessor, _ = make_diffusion_pre_post_processors(config, stats) # Modify the pipeline to use bfloat16 device processor with float32 normalizer modified_steps = [] - for step in factory_preprocessor.steps: + for step in preprocessor.steps: if isinstance(step, DeviceProcessorStep): # Device processor converts to bfloat16 
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) elif isinstance(step, NormalizerProcessorStep): # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep modified_steps.append( NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, device=config.device, dtype=torch.float32, # Deliberately configured as float32 ) ) else: modified_steps.append(step) - - # Create new processor with modified steps - preprocessor = DataProcessorPipeline( - modified_steps, to_transition=identity_transition, to_output=identity_transition - ) + preprocessor.steps = modified_steps # Verify initial normalizer configuration - normalizer_step = modified_steps[3] # NormalizerProcessorStep + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep assert normalizer_step.dtype == torch.float32 # Create test data with both state and visual observations @@ -433,15 +378,15 @@ def test_diffusion_processor_bfloat16_device_float32_normalizer(): action = torch.randn(6, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert ( - processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16 - ) # IDENTITY normalization still gets dtype conversion - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert 
processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16 diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py index 1745e1779..f2a7e36d0 100644 --- a/tests/processor/test_pi0_processor.py +++ b/tests/processor/test_pi0_processor.py @@ -34,7 +34,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch class MockTokenizerProcessorStep(ProcessorStep): @@ -91,8 +91,6 @@ def test_make_pi0_processor_basic(): preprocessor, postprocessor = make_pi0_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Check processor names @@ -195,8 +193,6 @@ def test_pi0_processor_cuda(): preprocessor, postprocessor = make_pi0_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data @@ -206,14 +202,15 @@ def test_pi0_processor_cuda(): } action = torch.randn(6) transition = create_transition(observation, action, complementary_data={"task": "test task"}) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + 
assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -250,8 +247,6 @@ def test_pi0_processor_accelerate_scenario(): preprocessor, postprocessor = make_pi0_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU and batched @@ -262,14 +257,15 @@ def test_pi0_processor_accelerate_scenario(): } action = torch.randn(1, 6).to(device) transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -306,8 +302,6 @@ def test_pi0_processor_multi_gpu(): preprocessor, postprocessor = make_pi0_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate data on different GPU @@ -318,14 +312,15 @@ def test_pi0_processor_multi_gpu(): } action = torch.randn(1, 6).to(device) transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = 
transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data stays on cuda:1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device def test_pi0_processor_without_stats(): @@ -337,8 +332,6 @@ def test_pi0_processor_without_stats(): preprocessor, postprocessor = make_pi0_pre_post_processors( config, dataset_stats=None, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Should still create processors @@ -376,8 +369,6 @@ def test_pi0_processor_bfloat16_device_float32_normalizer(): preprocessor, _ = make_pi0_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Modify the pipeline to use bfloat16 device processor with float32 normalizer @@ -388,11 +379,12 @@ def test_pi0_processor_bfloat16_device_float32_normalizer(): modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) elif isinstance(step, NormalizerProcessorStep): # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep modified_steps.append( NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, device=config.device, dtype=torch.float32, # Deliberately 
configured as float32 ) @@ -414,16 +406,15 @@ def test_pi0_processor_bfloat16_device_float32_normalizer(): transition = create_transition( observation, action, complementary_data={"task": "test bfloat16 adaptation"} ) + batch = transition_to_batch(transition) # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert ( - processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16 - ) # IDENTITY normalization still gets dtype conversion - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16 diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py index 8d2bd8453..71f5c366e 100644 --- a/tests/processor/test_sac_processor.py +++ b/tests/processor/test_sac_processor.py @@ -33,7 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch def create_default_config(): @@ -69,8 +69,6 @@ def test_make_sac_processor_basic(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Check processor names @@ -98,30 +96,28 @@ def test_sac_processor_normalization_modes(): 
preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization action = torch.rand(5) * 2 - 1 # Range [-1, 1] transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is normalized and batched # State should be mean-std normalized # Action should be min-max normalized to [-1, 1] - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) - assert processed[TransitionKey.ACTION].shape == (1, 5) + assert processed[OBS_STATE].shape == (1, 10) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) # Process action through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is unnormalized (but still batched) - assert postprocessed[TransitionKey.ACTION].shape == (1, 5) + assert postprocessed.shape == (1, 5) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -134,28 +130,26 @@ def test_sac_processor_cuda(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through 
preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is back on CPU - assert postprocessed[TransitionKey.ACTION].device.type == "cpu" + assert postprocessed.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -168,8 +162,6 @@ def test_sac_processor_accelerate_scenario(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU @@ -177,13 +169,14 @@ def test_sac_processor_accelerate_scenario(): observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -196,8 +189,6 @@ def test_sac_processor_multi_gpu(): 
preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate data on different GPU @@ -205,35 +196,21 @@ def test_sac_processor_multi_gpu(): observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data stays on cuda:1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device def test_sac_processor_without_stats(): """Test SAC processor creation without dataset statistics.""" config = create_default_config() - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - factory_preprocessor.steps, - name=factory_preprocessor.name, - to_transition=identity_transition, - to_output=identity_transition, - ) - postprocessor = DataProcessorPipeline( - factory_postprocessor.steps, - name=factory_postprocessor.name, - to_transition=identity_transition, - to_output=identity_transition, - ) + preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -243,8 +220,9 @@ def test_sac_processor_without_stats(): observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) transition = create_transition(observation, action) + batch = 
transition_to_batch(transition) - processed = preprocessor(transition) + processed = preprocessor(batch) assert processed is not None @@ -256,8 +234,6 @@ def test_sac_processor_save_and_load(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) with tempfile.TemporaryDirectory() as tmpdir: @@ -265,18 +241,17 @@ def test_sac_processor_save_and_load(): preprocessor.save_pretrained(tmpdir) # Load preprocessor - loaded_preprocessor = DataProcessorPipeline.from_pretrained( - tmpdir, to_transition=identity_transition, to_output=identity_transition - ) + loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir) # Test that loaded processor works observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = loaded_preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) - assert processed[TransitionKey.ACTION].shape == (1, 5) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 10) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -290,8 +265,6 @@ def test_sac_processor_mixed_precision(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Replace DeviceProcessorStep with one that uses float16 @@ -301,11 +274,12 @@ def test_sac_processor_mixed_precision(): modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) elif 
isinstance(step, NormalizerProcessorStep): # Update normalizer to use the same device as the device processor + norm_step = step # Now type checker knows this is NormalizerProcessorStep modified_steps.append( NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, device=config.device, dtype=torch.float16, # Match the float16 dtype ) @@ -318,13 +292,14 @@ def test_sac_processor_mixed_precision(): observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} action = torch.randn(5, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that data is converted to float16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 - assert processed[TransitionKey.ACTION].dtype == torch.float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 def test_sac_processor_batch_data(): @@ -335,8 +310,6 @@ def test_sac_processor_batch_data(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Test with batched data @@ -344,13 +317,14 @@ def test_sac_processor_batch_data(): observation = {OBS_STATE: torch.randn(batch_size, 10)} action = torch.randn(batch_size, 5) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through preprocessor - processed = preprocessor(transition) + processed = preprocessor(batch) # Check that batch dimension is preserved - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == 
(batch_size, 10) - assert processed[TransitionKey.ACTION].shape == (batch_size, 5) + assert processed[OBS_STATE].shape == (batch_size, 10) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5) def test_sac_processor_edge_cases(): @@ -361,22 +335,24 @@ def test_sac_processor_edge_cases(): preprocessor, postprocessor = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) - # Test with empty observation - transition = create_transition(observation={}, action=torch.randn(5)) - processed = preprocessor(transition) - assert processed[TransitionKey.OBSERVATION] == {} - assert processed[TransitionKey.ACTION].shape == (1, 5) + # Test with observation that has no state key but still exists + observation = {"observation.dummy": torch.randn(1)} # Some dummy observation to pass validation + action = torch.randn(5) + batch = {TransitionKey.ACTION.value: action, **observation} + processed = preprocessor(batch) + # observation.state wasn't in original, so it won't be in processed + assert OBS_STATE not in processed + assert processed[TransitionKey.ACTION.value].shape == (1, 5) # Test with zero action (representing "null" action) transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5)) - processed = preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) + batch = transition_to_batch(transition) + processed = preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 10) # Action should be present and batched, even if it's zeros - assert processed[TransitionKey.ACTION].shape == (1, 5) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -389,8 +365,6 @@ def 
test_sac_processor_bfloat16_device_float32_normalizer(): preprocessor, _ = make_sac_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Modify the pipeline to use bfloat16 device processor with float32 normalizer @@ -401,11 +375,12 @@ def test_sac_processor_bfloat16_device_float32_normalizer(): modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) elif isinstance(step, NormalizerProcessorStep): # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep modified_steps.append( NormalizerProcessorStep( - features=step.features, - norm_map=step.norm_map, - stats=step.stats, + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, device=config.device, dtype=torch.float32, # Deliberately configured as float32 ) @@ -422,13 +397,14 @@ def test_sac_processor_bfloat16_device_float32_normalizer(): observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32 action = torch.randn(5, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16 diff --git a/tests/processor/test_smolvla_processor.py 
b/tests/processor/test_smolvla_processor.py index f2dd0156f..c37cd3eef 100644 --- a/tests/processor/test_smolvla_processor.py +++ b/tests/processor/test_smolvla_processor.py @@ -37,7 +37,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch class MockTokenizerProcessorStep(ProcessorStep): @@ -98,8 +98,6 @@ def test_make_smolvla_processor_basic(): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Check processor names @@ -204,8 +202,6 @@ def test_smolvla_processor_cuda(): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data @@ -216,13 +212,16 @@ def test_smolvla_processor_cuda(): action = torch.randn(7) transition = create_transition(observation, action, complementary_data={"task": "test task"}) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not 
available") @@ -261,8 +260,6 @@ def test_smolvla_processor_accelerate_scenario(): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU and batched @@ -274,13 +271,16 @@ def test_smolvla_processor_accelerate_scenario(): action = torch.randn(1, 7).to(device) transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -319,8 +319,6 @@ def test_smolvla_processor_multi_gpu(): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate data on different GPU @@ -332,13 +330,16 @@ def test_smolvla_processor_multi_gpu(): action = torch.randn(1, 7).to(device) transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on cuda:1 
- assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device def test_smolvla_processor_without_stats(): @@ -352,8 +353,6 @@ def test_smolvla_processor_without_stats(): preprocessor, postprocessor = make_smolvla_pre_post_processors( config, dataset_stats=None, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Should still create processors @@ -405,8 +404,6 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer(): preprocessor, _ = make_smolvla_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Modify the pipeline to use bfloat16 device processor with float32 normalizer @@ -444,15 +441,15 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer(): observation, action, complementary_data={"task": "test bfloat16 adaptation"} ) + batch = transition_to_batch(transition) + # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert ( - processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16 - ) # IDENTITY normalization still gets dtype conversion - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert 
processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16 diff --git a/tests/processor/test_tdmpc_processor.py b/tests/processor/test_tdmpc_processor.py index 660fe10ea..1aae97328 100644 --- a/tests/processor/test_tdmpc_processor.py +++ b/tests/processor/test_tdmpc_processor.py @@ -33,7 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch def create_default_config(): @@ -72,8 +72,6 @@ def test_make_tdmpc_processor_basic(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Check processor names @@ -101,8 +99,6 @@ def test_tdmpc_processor_normalization(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data @@ -113,20 +109,22 @@ def test_tdmpc_processor_normalization(): action = torch.randn(6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is processed and batched - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 12) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert 
processed[TransitionKey.ACTION].shape == (1, 6) + assert processed[OBS_STATE].shape == (1, 12) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) # Process action through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is unnormalized (but still batched) - assert postprocessed[TransitionKey.ACTION].shape == (1, 6) + assert postprocessed.shape == (1, 6) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -139,8 +137,6 @@ def test_tdmpc_processor_cuda(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data @@ -151,20 +147,22 @@ def test_tdmpc_processor_cuda(): action = torch.randn(6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is back 
on CPU - assert postprocessed[TransitionKey.ACTION].device.type == "cpu" + assert postprocessed.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -177,8 +175,6 @@ def test_tdmpc_processor_accelerate_scenario(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU @@ -190,13 +186,16 @@ def test_tdmpc_processor_accelerate_scenario(): action = torch.randn(6).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -209,8 +208,6 @@ def test_tdmpc_processor_multi_gpu(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate data on different GPU @@ -222,35 +219,23 @@ def test_tdmpc_processor_multi_gpu(): action = torch.randn(6).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + 
+ processed = preprocessor(batch) # Check that data stays on cuda:1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device def test_tdmpc_processor_without_stats(): """Test TDMPC processor creation without dataset statistics.""" config = create_default_config() - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - factory_preprocessor.steps, - name=factory_preprocessor.name, - to_transition=identity_transition, - to_output=identity_transition, - ) - postprocessor = DataProcessorPipeline( - factory_postprocessor.steps, - name=factory_postprocessor.name, - to_transition=identity_transition, - to_output=identity_transition, - ) + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -263,8 +248,9 @@ def test_tdmpc_processor_without_stats(): } action = torch.randn(6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) - processed = preprocessor(transition) + processed = preprocessor(batch) assert processed is not None @@ -276,8 +262,6 @@ def test_tdmpc_processor_save_and_load(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) with tempfile.TemporaryDirectory() as tmpdir: @@ -285,9 +269,7 @@ def 
test_tdmpc_processor_save_and_load(): preprocessor.save_pretrained(tmpdir) # Load preprocessor - loaded_preprocessor = DataProcessorPipeline.from_pretrained( - tmpdir, to_transition=identity_transition, to_output=identity_transition - ) + loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir) # Test that loaded processor works observation = { @@ -297,10 +279,11 @@ def test_tdmpc_processor_save_and_load(): action = torch.randn(6) transition = create_transition(observation, action) - processed = loaded_preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 12) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (1, 6) + batch = transition_to_batch(transition) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 12) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -314,8 +297,6 @@ def test_tdmpc_processor_mixed_precision(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Replace DeviceProcessorStep with one that uses float16 @@ -346,13 +327,16 @@ def test_tdmpc_processor_mixed_precision(): action = torch.randn(6, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is converted to float16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16 - 
assert processed[TransitionKey.ACTION].dtype == torch.float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 def test_tdmpc_processor_batch_data(): @@ -363,8 +347,6 @@ def test_tdmpc_processor_batch_data(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Test with batched data @@ -376,13 +358,16 @@ def test_tdmpc_processor_batch_data(): action = torch.randn(batch_size, 6) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that batch dimension is preserved - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 12) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (batch_size, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (batch_size, 6) + assert processed[OBS_STATE].shape == (batch_size, 12) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 6) def test_tdmpc_processor_edge_cases(): @@ -393,8 +378,6 @@ def test_tdmpc_processor_edge_cases(): preprocessor, postprocessor = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Test with only state observation (no image) @@ -402,17 +385,21 @@ def test_tdmpc_processor_edge_cases(): action = torch.randn(6) transition = create_transition(observation, action) - processed = preprocessor(transition) 
- assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 12) - assert OBS_IMAGE not in processed[TransitionKey.OBSERVATION] + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 12) + assert OBS_IMAGE not in processed # Test with only image observation (no state) observation = {OBS_IMAGE: torch.randn(3, 224, 224)} transition = create_transition(observation, action) - processed = preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert OBS_STATE not in processed[TransitionKey.OBSERVATION] + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert OBS_STATE not in processed @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -425,7 +412,6 @@ def test_tdmpc_processor_bfloat16_device_float32_normalizer(): preprocessor, _ = make_tdmpc_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Modify the pipeline to use bfloat16 device processor with float32 normalizer @@ -461,15 +447,15 @@ def test_tdmpc_processor_bfloat16_device_float32_normalizer(): action = torch.randn(6, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert ( - processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16 - ) # IDENTITY normalization still gets dtype conversion - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert 
processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16 diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py index 0d773993e..d80d85101 100644 --- a/tests/processor/test_vqbet_processor.py +++ b/tests/processor/test_vqbet_processor.py @@ -33,7 +33,7 @@ from lerobot.processor import ( TransitionKey, UnnormalizerProcessorStep, ) -from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.converters import create_transition, transition_to_batch def create_default_config(): @@ -72,8 +72,6 @@ def test_make_vqbet_processor_basic(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Check processor names @@ -101,8 +99,6 @@ def test_vqbet_processor_with_images(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create test data with images and states @@ -113,13 +109,16 @@ def test_vqbet_processor_with_images(): action = torch.randn(7) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is batched - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert 
processed[TransitionKey.ACTION].shape == (1, 7) + assert processed[OBS_STATE].shape == (1, 8) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 7) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -132,8 +131,6 @@ def test_vqbet_processor_cuda(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Create CPU data @@ -144,20 +141,22 @@ def test_vqbet_processor_cuda(): action = torch.randn(7) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is on CUDA - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION].device.type == "cuda" + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" # Process through postprocessor - action_transition = create_transition(action=processed[TransitionKey.ACTION]) - postprocessed = postprocessor(action_transition) + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) # Check that action is back on CPU - assert postprocessed[TransitionKey.ACTION].device.type == "cpu" + assert postprocessed.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -170,8 +169,6 @@ def test_vqbet_processor_accelerate_scenario(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": 
identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate Accelerate: data already on GPU and batched @@ -183,13 +180,16 @@ def test_vqbet_processor_accelerate_scenario(): action = torch.randn(1, 7).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on same GPU - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -202,8 +202,6 @@ def test_vqbet_processor_multi_gpu(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Simulate data on different GPU @@ -215,35 +213,23 @@ def test_vqbet_processor_multi_gpu(): action = torch.randn(1, 7).to(device) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data stays on cuda:1 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION].device == device + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == 
device + assert processed[TransitionKey.ACTION.value].device == device def test_vqbet_processor_without_stats(): """Test VQBeT processor creation without dataset statistics.""" config = create_default_config() - # Get the steps from the factory function - factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None) - - # Create new processors with EnvTransition input/output - preprocessor = DataProcessorPipeline( - factory_preprocessor.steps, - name=factory_preprocessor.name, - to_transition=identity_transition, - to_output=identity_transition, - ) - postprocessor = DataProcessorPipeline( - factory_postprocessor.steps, - name=factory_postprocessor.name, - to_transition=identity_transition, - to_output=identity_transition, - ) + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None) # Should still create processors assert preprocessor is not None @@ -257,7 +243,9 @@ def test_vqbet_processor_without_stats(): action = torch.randn(7) transition = create_transition(observation, action) - processed = preprocessor(transition) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) assert processed is not None @@ -269,8 +257,6 @@ def test_vqbet_processor_save_and_load(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) with tempfile.TemporaryDirectory() as tmpdir: @@ -278,9 +264,7 @@ def test_vqbet_processor_save_and_load(): preprocessor.save_pretrained(tmpdir) # Load preprocessor - loaded_preprocessor = DataProcessorPipeline.from_pretrained( - tmpdir, to_transition=identity_transition, to_output=identity_transition - ) + loaded_preprocessor = DataProcessorPipeline.from_pretrained(tmpdir) # Test that loaded processor works observation = { @@ -290,10 
+274,11 @@ def test_vqbet_processor_save_and_load(): action = torch.randn(7) transition = create_transition(observation, action) - processed = loaded_preprocessor(transition) - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (1, 7) + batch = transition_to_batch(transition) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 8) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 7) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -307,8 +292,6 @@ def test_vqbet_processor_mixed_precision(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Replace DeviceProcessorStep with one that uses float16 @@ -339,13 +322,16 @@ def test_vqbet_processor_mixed_precision(): action = torch.randn(7, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that data is converted to float16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.float16 - assert processed[TransitionKey.ACTION].dtype == torch.float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 def test_vqbet_processor_large_batch(): @@ -356,8 +342,6 @@ def test_vqbet_processor_large_batch(): preprocessor, postprocessor = 
make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Test with large batch @@ -369,13 +353,16 @@ def test_vqbet_processor_large_batch(): action = torch.randn(batch_size, 7) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through preprocessor - processed = preprocessor(transition) + + processed = preprocessor(batch) # Check that batch dimension is preserved - assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (batch_size, 8) - assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (batch_size, 3, 224, 224) - assert processed[TransitionKey.ACTION].shape == (batch_size, 7) + assert processed[OBS_STATE].shape == (batch_size, 8) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 7) def test_vqbet_processor_sequential_processing(): @@ -386,8 +373,6 @@ def test_vqbet_processor_sequential_processing(): preprocessor, postprocessor = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Process multiple samples sequentially @@ -400,14 +385,16 @@ def test_vqbet_processor_sequential_processing(): action = torch.randn(7) transition = create_transition(observation, action) - processed = preprocessor(transition) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) results.append(processed) # Check that all results are consistent for result in results: - assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8) - assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) - assert 
result[TransitionKey.ACTION].shape == (1, 7) + assert result[OBS_STATE].shape == (1, 8) + assert result[OBS_IMAGE].shape == (1, 3, 224, 224) + assert result[TransitionKey.ACTION.value].shape == (1, 7) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -420,8 +407,6 @@ def test_vqbet_processor_bfloat16_device_float32_normalizer(): preprocessor, _ = make_vqbet_pre_post_processors( config, stats, - preprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, - postprocessor_kwargs={"to_transition": identity_transition, "to_output": identity_transition}, ) # Modify the pipeline to use bfloat16 device processor with float32 normalizer @@ -457,15 +442,15 @@ def test_vqbet_processor_bfloat16_device_float32_normalizer(): action = torch.randn(7, dtype=torch.float32) transition = create_transition(observation, action) + batch = transition_to_batch(transition) + # Process through full pipeline - processed = preprocessor(transition) + processed = preprocessor(batch) # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 - assert ( - processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16 - ) # IDENTITY normalization still gets dtype conversion - assert processed[TransitionKey.ACTION].dtype == torch.bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 # Verify normalizer automatically adapted its internal state assert normalizer_step.dtype == torch.bfloat16