diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py
index 48fcb7407..baef233c5 100644
--- a/examples/phone_to_so100/evaluate.py
+++ b/examples/phone_to_so100/evaluate.py
@@ -65,7 +65,7 @@ kinematics_solver = RobotKinematics(
 )
 
 # Build pipeline to convert ee pose action to joint action
-robot_ee_to_joints = RobotProcessor(
+robot_ee_to_joints_processor = RobotProcessor(
     steps=[
         AddRobotObservationAsComplimentaryData(robot=robot),
         InverseKinematicsEEToJoints(
@@ -79,7 +79,7 @@ robot_ee_to_joints = RobotProcessor(
 )
 
 # Build pipeline to convert joint observation to ee pose observation
-robot_joints_to_ee_pose = RobotProcessor(
+robot_joints_to_ee_pose_processor = RobotProcessor(
     steps=[
         ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
     ],
@@ -89,7 +89,7 @@
 
 # Build dataset action and gripper features
 action_ee_and_gripper = aggregate_pipeline_dataset_features(
-    pipeline=robot_ee_to_joints,
+    pipeline=robot_ee_to_joints_processor,
     initial_features={},
     use_videos=True,
     patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
@@ -97,7 +97,7 @@ action_ee_and_gripper = aggregate_pipeline_dataset_features(
 
 # Build dataset observation features
 obs_ee = aggregate_pipeline_dataset_features(
-    pipeline=robot_joints_to_ee_pose,
+    pipeline=robot_joints_to_ee_pose_processor,
     initial_features=robot.observation_features,
     use_videos=True,
     patterns=["observation.state.ee"],
@@ -147,8 +147,8 @@ for episode_idx in range(NUM_EPISODES):
         control_time_s=EPISODE_TIME_SEC,
         single_task=TASK_DESCRIPTION,
         display_data=True,
-        robot_action_processor=robot_ee_to_joints,
-        robot_observation_processor=robot_joints_to_ee_pose,
+        robot_action_processor=robot_ee_to_joints_processor,
+        robot_observation_processor=robot_joints_to_ee_pose_processor,
     )
 
     dataset.save_episode()
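For reference, a minimal usage sketch of the two renamed evaluate.py pipelines. The "action.ee.*" keys and numeric values below are illustrative assumptions only, and robot.get_observation() is assumed to return the joint-space observation dict the FK step expects:

    # Hypothetical EE-pose action; in evaluate.py the actions come from the policy/dataset.
    ee_action = {"action.ee.x": 0.20, "action.ee.y": 0.00, "action.ee.z": 0.15}

    # IK pipeline (renamed above): EE pose -> joint targets.
    joint_action = robot_ee_to_joints_processor(ee_action)

    # FK pipeline (renamed above): raw joint observation -> EE-pose observation.
    ee_observation = robot_joints_to_ee_pose_processor(robot.get_observation())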
diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py
index 32723c280..bde2ebc2d 100644
--- a/examples/phone_to_so100/record.py
+++ b/examples/phone_to_so100/record.py
@@ -73,7 +73,7 @@ kinematics_solver = RobotKinematics(
 )
 
 # Build pipeline to convert phone action to ee pose action
-phone_to_robot_ee_pose = RobotProcessor(
+phone_to_robot_ee_pose_processor = RobotProcessor(
     steps=[
         MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
         AddRobotObservationAsComplimentaryData(robot=robot),
@@ -93,7 +93,7 @@ phone_to_robot_ee_pose = RobotProcessor(
 )
 
 # Build pipeline to convert ee pose action to joint action
-robot_ee_to_joints = RobotProcessor(
+robot_ee_to_joints_processor = RobotProcessor(
     steps=[
         InverseKinematicsEEToJoints(
             kinematics=kinematics_solver,
@@ -120,7 +120,7 @@ robot_joints_to_ee_pose = RobotProcessor(
 
 # Build dataset ee action features
 action_ee = aggregate_pipeline_dataset_features(
-    pipeline=phone_to_robot_ee_pose,
+    pipeline=phone_to_robot_ee_pose_processor,
     initial_features=phone.action_features,
     use_videos=True,
     patterns=["action.ee"],
@@ -128,7 +128,7 @@ action_ee = aggregate_pipeline_dataset_features(
 
 # Get gripper pos action features
 gripper = aggregate_pipeline_dataset_features(
-    pipeline=robot_ee_to_joints,
+    pipeline=robot_ee_to_joints_processor,
    initial_features={},
     use_videos=True,
     patterns=["action.gripper.pos", "observation.state.gripper.pos"],
@@ -177,8 +177,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
         control_time_s=EPISODE_TIME_SEC,
         single_task=TASK_DESCRIPTION,
         display_data=True,
-        teleop_action_processor=phone_to_robot_ee_pose,
-        robot_action_processor=robot_ee_to_joints,
+        teleop_action_processor=phone_to_robot_ee_pose_processor,
+        robot_action_processor=robot_ee_to_joints_processor,
         robot_observation_processor=robot_joints_to_ee_pose,
     )
 
@@ -193,8 +193,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
         control_time_s=RESET_TIME_SEC,
         single_task=TASK_DESCRIPTION,
         display_data=True,
-        teleop_action_processor=phone_to_robot_ee_pose,
-        robot_action_processor=robot_ee_to_joints,
+        teleop_action_processor=phone_to_robot_ee_pose_processor,
+        robot_action_processor=robot_ee_to_joints_processor,
         robot_observation_processor=robot_joints_to_ee_pose,
     )
 
diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py
index 83938d7ca..f7c0d395f 100644
--- a/examples/phone_to_so100/replay.py
+++ b/examples/phone_to_so100/replay.py
@@ -50,7 +50,7 @@ kinematics_solver = RobotKinematics(
 )
 
 # Build pipeline to convert ee pose action to joint action
-robot_ee_to_joints = RobotProcessor(
+robot_ee_to_joints_processor = RobotProcessor(
     steps=[
         AddRobotObservationAsComplimentaryData(robot=robot),
         InverseKinematicsEEToJoints(
@@ -63,7 +63,7 @@ robot_ee_to_joints = RobotProcessor(
     to_output=to_output_robot_action,
 )
 
-robot_ee_to_joints.reset()
+robot_ee_to_joints_processor.reset()
 
 log_say(f"Replaying episode {EPISODE_IDX}")
 for idx in range(dataset.num_frames):
@@ -73,7 +73,7 @@ for idx in range(dataset.num_frames):
         name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
     }
 
-    joint_action = robot_ee_to_joints(ee_action)
+    joint_action = robot_ee_to_joints_processor(ee_action)
     action_sent = robot.send_action(joint_action)
 
     busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py
index 1eef0f8ae..e937265a6 100644
--- a/examples/phone_to_so100/teleoperate.py
+++ b/examples/phone_to_so100/teleoperate.py
@@ -49,7 +49,7 @@ kinematics_solver = RobotKinematics(
 )
 
 # Build pipeline to convert phone action to ee pose action to joint action
-phone_to_robot_joints = RobotProcessor(
+phone_to_robot_joints_processor = RobotProcessor(
     steps=[
         MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
         AddRobotObservationAsComplimentaryData(robot=robot),
@@ -85,7 +85,7 @@ while True:
     phone_obs = teleop_device.get_action()
 
     # Phone -> EE pose -> Joints transition
-    joint_action = phone_to_robot_joints(phone_obs)
+    joint_action = phone_to_robot_joints_processor(phone_obs)
 
     if joint_action:
         robot.send_action(joint_action)
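The teleoperate.py pipeline fuses the two stages that record.py keeps separate. A sketch of that relationship, under the assumption that the fused step list matches the two staged ones shown above (the outputs are only expected to be comparable, not bit-identical, since converters and gripper handling differ between the scripts):

    # Staged, as in record.py: phone -> EE pose, then EE pose -> joints.
    ee_action = phone_to_robot_ee_pose_processor(phone_obs)
    joint_action = robot_ee_to_joints_processor(ee_action)

    # Fused, as in teleoperate.py: phone -> EE pose -> joints in one pass.
    joint_action = phone_to_robot_joints_processor(phone_obs)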
diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py
index abe64599d..ebe7d6afd 100644
--- a/src/lerobot/processor/pipeline.py
+++ b/src/lerobot/processor/pipeline.py
@@ -32,7 +32,7 @@ from safetensors.torch import load_file, save_file
 
 from lerobot.configs.types import PolicyFeature
 
-from .converters import batch_to_transition, transition_to_batch
+from .converters import batch_to_transition, create_transition, transition_to_batch
 from .core import EnvTransition, TransitionKey
 
 # Type variable for generic processor output type
@@ -276,6 +276,12 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
         # Always convert input through to_transition
         transition = self.to_transition(data)
 
+        transformed_transition = self._forward(transition)
+
+        # Always use to_output for consistent typing
+        return self.to_output(transformed_transition)
+
+    def _forward(self, transition: EnvTransition) -> EnvTransition:
         # Process through all steps
         for idx, processor_step in enumerate(self.steps):
             # Apply before hooks
@@ -288,9 +294,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
             # Apply after hooks
             for hook in self.after_step_hooks:
                 hook(idx, transition)
-
-        # Always use to_output for consistent typing
-        return self.to_output(transition)
+        return transition
 
     def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
         """Yield the intermediate results after each processor step.
@@ -763,6 +767,41 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
             features = out
         return features
 
+    def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
+        transition: EnvTransition = create_transition(observation=observation)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.OBSERVATION]
+
+    def process_action(self, action: Any | torch.Tensor) -> Any | torch.Tensor:
+        transition: EnvTransition = create_transition(action=action)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.ACTION]
+
+    def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
+        transition: EnvTransition = create_transition(reward=reward)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.REWARD]
+
+    def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
+        transition: EnvTransition = create_transition(done=done)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.DONE]
+
+    def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
+        transition: EnvTransition = create_transition(truncated=truncated)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.TRUNCATED]
+
+    def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
+        transition: EnvTransition = create_transition(info=info)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.INFO]
+
+    def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
+        transition: EnvTransition = create_transition(complementary_data=complementary_data)
+        transformed_transition = self._forward(transition)
+        return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
+
 
 class ObservationProcessor(ProcessorStep, ABC):
     """Base class for processors that modify only the observation component of a transition.
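The new process_* helpers each wrap a single field into an EnvTransition via create_transition, run it through _forward, and return that same field, bypassing the to_transition/to_output converters that __call__ applies. A usage sketch (processor names reuse the examples above; inputs are assumed to be raw field values):

    # Run only an observation through the pipeline steps.
    ee_obs = robot_joints_to_ee_pose_processor.process_observation(raw_joint_obs)

    # Run only an action through the pipeline steps.
    joint_action = robot_ee_to_joints_processor.process_action(ee_action)

Note that because create_transition populates only the one field, steps that read other parts of the transition (e.g. complementary data added by AddRobotObservationAsComplimentaryData) will see empty values when invoked through these helpers.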