From 43eb0e375f858ede324d592b94c30fa180c98c99 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 16 Sep 2025 17:13:07 +0200 Subject: [PATCH] fix(processor): enforce signatures --- examples/so100_to_so100_EE/record.py | 5 ++--- src/lerobot/processor/converters.py | 22 ++++++++++++++++++---- src/lerobot/replay.py | 13 +++---------- src/lerobot/teleoperate.py | 20 +++++++++----------- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index b6c137d6d..41dab2fdd 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -14,14 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotProcessorPipeline +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_to_transition, @@ -76,7 +75,7 @@ leader_kinematics_solver = RobotKinematics( ) # Build pipeline to convert follower joints to EE observation -follower_joints_to_ee = RobotProcessorPipeline[dict[str, Any], dict[str, Any]]( +follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation]( steps=[ ForwardKinematicsJointsToEE( kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys()) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index be38feb16..60c7e6579 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -221,7 +221,8 @@ def robot_action_to_transition(action: RobotAction) -> EnvTransition: Returns: An `EnvTransition` containing the formatted action. """ - + if not isinstance(action, RobotAction): + raise ValueError(f"Action should be a RobotAction type got {type(action)}") return create_transition(action=action) @@ -239,7 +240,8 @@ def observation_to_transition(observation: RobotObservation) -> EnvTransition: Returns: An `EnvTransition` containing the formatted observation. """ - + if not isinstance(observation, RobotObservation): + raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}") return create_transition(observation=observation) @@ -256,6 +258,9 @@ def transition_to_robot_action(transition: EnvTransition) -> RobotAction: Returns: A dictionary representing the raw robot action. """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + action = transition.get(TransitionKey.ACTION) if not isinstance(action, dict): raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}") @@ -266,6 +271,9 @@ def transition_to_policy_action(transition: EnvTransition) -> PolicyAction: """ Convert an `EnvTransition` to a `PolicyAction`. """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + action = transition.get(TransitionKey.ACTION) if not isinstance(action, PolicyAction): raise ValueError(f"Action should be a PolicyAction type got {type(action)}") @@ -276,6 +284,9 @@ def transition_to_observation(transition: EnvTransition) -> RobotObservation: """ Convert an `EnvTransition` to a `RobotObservation`. """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + observation = transition.get(TransitionKey.OBSERVATION) if not isinstance(observation, dict): raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}") @@ -343,6 +354,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: Returns: A batch dictionary with canonical LeRobot field names. """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + batch = { "action": transition.get(TransitionKey.ACTION), "next.reward": transition.get(TransitionKey.REWARD, 0.0), @@ -364,7 +378,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: return batch -def identity_transition(tr: EnvTransition) -> EnvTransition: +def identity_transition(transition: EnvTransition) -> EnvTransition: """ An identity function for transitions, returning the input unchanged. @@ -376,4 +390,4 @@ def identity_transition(tr: EnvTransition) -> EnvTransition: Returns: The same `EnvTransition`. """ - return tr + return transition diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index d92520b8a..21446434a 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -48,11 +48,8 @@ from pprint import pformat from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.processor import ( - IdentityProcessorStep, - RobotAction, - RobotProcessorPipeline, + make_default_robot_action_processor, ) -from lerobot.processor.converters import robot_action_to_transition, transition_to_robot_action from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -96,11 +93,7 @@ def replay(cfg: ReplayConfig): init_logging() logging.info(pformat(asdict(cfg))) - robot_action_processor = RobotProcessorPipeline[RobotAction, RobotAction]( - steps=[IdentityProcessorStep()], - to_transition=robot_action_to_transition, - to_output=transition_to_robot_action, - ) + robot_action_processor = make_default_robot_action_processor() robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) @@ -122,7 +115,7 @@ def replay(cfg: ReplayConfig): processed_action = robot_action_processor(action) - robot.send_action(processed_action) + _ = robot.send_action(processed_action) dt_s = time.perf_counter() - start_episode_t busy_wait(1 / dataset.fps - dt_s) diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 5f1fbdddd..3abd60cdf 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -55,7 +55,6 @@ import logging import time from dataclasses import asdict, dataclass from pprint import pformat -from typing import Any import rerun as rr @@ -63,10 +62,9 @@ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # no from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.processor import ( - EnvTransition, RobotAction, + RobotObservation, RobotProcessorPipeline, - TransitionKey, make_default_processors, ) from lerobot.robots import ( # noqa: F401 @@ -111,9 +109,9 @@ def teleop_loop( teleop: Teleoperator, robot: Robot, fps: int, - teleop_action_processor: RobotProcessorPipeline[RobotAction, EnvTransition], - robot_action_processor: RobotProcessorPipeline[EnvTransition, RobotAction], - robot_observation_processor: RobotProcessorPipeline[dict[str, Any], EnvTransition], + teleop_action_processor: RobotProcessorPipeline[RobotAction, RobotAction], + robot_action_processor: RobotProcessorPipeline[RobotAction, RobotAction], + robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], display_data: bool = False, duration: float | None = None, ): @@ -143,13 +141,13 @@ def teleop_loop( raw_action = teleop.get_action() # Process teleop action through pipeline - teleop_transition = teleop_action_processor(raw_action) + teleop_action = teleop_action_processor(raw_action) # Process action for robot through pipeline - robot_action_to_send = robot_action_processor(teleop_transition) + robot_action_to_send = robot_action_processor(teleop_action) # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) - robot.send_action(robot_action_to_send) + _ = robot.send_action(robot_action_to_send) if display_data: # Get robot observation @@ -158,8 +156,8 @@ def teleop_loop( obs_transition = robot_observation_processor(obs) log_rerun_data( - observation=obs_transition.get(TransitionKey.OBSERVATION), - action=teleop_transition.get(TransitionKey.ACTION), + observation=obs_transition, + action=teleop_action, ) print("\n" + "-" * (display_len + 10))