refactor(processors): add extended api for specialized pipelines (#1848)

This commit is contained in:
Steven Palma
2025-09-03 12:28:40 +02:00
committed by GitHub
parent b052843f08
commit 2fcc358e98
5 changed files with 62 additions and 23 deletions
+6 -6
View File
@@ -65,7 +65,7 @@ kinematics_solver = RobotKinematics(
) )
# Build pipeline to convert ee pose action to joint action # Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor( robot_ee_to_joints_processor = RobotProcessor(
steps=[ steps=[
AddRobotObservationAsComplimentaryData(robot=robot), AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints( InverseKinematicsEEToJoints(
@@ -79,7 +79,7 @@ robot_ee_to_joints = RobotProcessor(
) )
# Build pipeline to convert joint observation to ee pose observation # Build pipeline to convert joint observation to ee pose observation
robot_joints_to_ee_pose = RobotProcessor( robot_joints_to_ee_pose_processor = RobotProcessor(
steps=[ steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
], ],
@@ -89,7 +89,7 @@ robot_joints_to_ee_pose = RobotProcessor(
# Build dataset action and gripper features # Build dataset action and gripper features
action_ee_and_gripper = aggregate_pipeline_dataset_features( action_ee_and_gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints, pipeline=robot_ee_to_joints_processor,
initial_features={}, initial_features={},
use_videos=True, use_videos=True,
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"], patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
@@ -97,7 +97,7 @@ action_ee_and_gripper = aggregate_pipeline_dataset_features(
# Build dataset observation features # Build dataset observation features
obs_ee = aggregate_pipeline_dataset_features( obs_ee = aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose, pipeline=robot_joints_to_ee_pose_processor,
initial_features=robot.observation_features, initial_features=robot.observation_features,
use_videos=True, use_videos=True,
patterns=["observation.state.ee"], patterns=["observation.state.ee"],
@@ -147,8 +147,8 @@ for episode_idx in range(NUM_EPISODES):
control_time_s=EPISODE_TIME_SEC, control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
robot_action_processor=robot_ee_to_joints, robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose, robot_observation_processor=robot_joints_to_ee_pose_processor,
) )
dataset.save_episode() dataset.save_episode()
+8 -8
View File
@@ -73,7 +73,7 @@ kinematics_solver = RobotKinematics(
) )
# Build pipeline to convert phone action to ee pose action # Build pipeline to convert phone action to ee pose action
phone_to_robot_ee_pose = RobotProcessor( phone_to_robot_ee_pose_processor = RobotProcessor(
steps=[ steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os), MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot), AddRobotObservationAsComplimentaryData(robot=robot),
@@ -93,7 +93,7 @@ phone_to_robot_ee_pose = RobotProcessor(
) )
# Build pipeline to convert ee pose action to joint action # Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor( robot_ee_to_joints_processor = RobotProcessor(
steps=[ steps=[
InverseKinematicsEEToJoints( InverseKinematicsEEToJoints(
kinematics=kinematics_solver, kinematics=kinematics_solver,
@@ -120,7 +120,7 @@ robot_joints_to_ee_pose = RobotProcessor(
# Build dataset ee action features # Build dataset ee action features
action_ee = aggregate_pipeline_dataset_features( action_ee = aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose, pipeline=phone_to_robot_ee_pose_processor,
initial_features=phone.action_features, initial_features=phone.action_features,
use_videos=True, use_videos=True,
patterns=["action.ee"], patterns=["action.ee"],
@@ -128,7 +128,7 @@ action_ee = aggregate_pipeline_dataset_features(
# Get gripper pos action features # Get gripper pos action features
gripper = aggregate_pipeline_dataset_features( gripper = aggregate_pipeline_dataset_features(
pipeline=robot_ee_to_joints, pipeline=robot_ee_to_joints_processor,
initial_features={}, initial_features={},
use_videos=True, use_videos=True,
patterns=["action.gripper.pos", "observation.state.gripper.pos"], patterns=["action.gripper.pos", "observation.state.gripper.pos"],
@@ -177,8 +177,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
control_time_s=EPISODE_TIME_SEC, control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=phone_to_robot_ee_pose, teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints, robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose, robot_observation_processor=robot_joints_to_ee_pose,
) )
@@ -193,8 +193,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
control_time_s=RESET_TIME_SEC, control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION, single_task=TASK_DESCRIPTION,
display_data=True, display_data=True,
teleop_action_processor=phone_to_robot_ee_pose, teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints, robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose, robot_observation_processor=robot_joints_to_ee_pose,
) )
+3 -3
View File
@@ -50,7 +50,7 @@ kinematics_solver = RobotKinematics(
) )
# Build pipeline to convert ee pose action to joint action # Build pipeline to convert ee pose action to joint action
robot_ee_to_joints = RobotProcessor( robot_ee_to_joints_processor = RobotProcessor(
steps=[ steps=[
AddRobotObservationAsComplimentaryData(robot=robot), AddRobotObservationAsComplimentaryData(robot=robot),
InverseKinematicsEEToJoints( InverseKinematicsEEToJoints(
@@ -63,7 +63,7 @@ robot_ee_to_joints = RobotProcessor(
to_output=to_output_robot_action, to_output=to_output_robot_action,
) )
robot_ee_to_joints.reset() robot_ee_to_joints_processor.reset()
log_say(f"Replaying episode {EPISODE_IDX}") log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(dataset.num_frames): for idx in range(dataset.num_frames):
@@ -73,7 +73,7 @@ for idx in range(dataset.num_frames):
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
} }
joint_action = robot_ee_to_joints(ee_action) joint_action = robot_ee_to_joints_processor(ee_action)
action_sent = robot.send_action(joint_action) action_sent = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
+2 -2
View File
@@ -49,7 +49,7 @@ kinematics_solver = RobotKinematics(
) )
# Build pipeline to convert phone action to ee pose action to joint action # Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints = RobotProcessor( phone_to_robot_joints_processor = RobotProcessor(
steps=[ steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os), MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
AddRobotObservationAsComplimentaryData(robot=robot), AddRobotObservationAsComplimentaryData(robot=robot),
@@ -85,7 +85,7 @@ while True:
phone_obs = teleop_device.get_action() phone_obs = teleop_device.get_action()
# Phone -> EE pose -> Joints transition # Phone -> EE pose -> Joints transition
joint_action = phone_to_robot_joints(phone_obs) joint_action = phone_to_robot_joints_processor(phone_obs)
if joint_action: if joint_action:
robot.send_action(joint_action) robot.send_action(joint_action)
+43 -4
View File
@@ -32,7 +32,7 @@ from safetensors.torch import load_file, save_file
from lerobot.configs.types import PolicyFeature from lerobot.configs.types import PolicyFeature
from .converters import batch_to_transition, transition_to_batch from .converters import batch_to_transition, create_transition, transition_to_batch
from .core import EnvTransition, TransitionKey from .core import EnvTransition, TransitionKey
# Type variable for generic processor output type # Type variable for generic processor output type
@@ -276,6 +276,12 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
# Always convert input through to_transition # Always convert input through to_transition
transition = self.to_transition(data) transition = self.to_transition(data)
transformed_transition = self._forward(transition)
# Always use to_output for consistent typing
return self.to_output(transformed_transition)
def _forward(self, transition: EnvTransition) -> EnvTransition:
# Process through all steps # Process through all steps
for idx, processor_step in enumerate(self.steps): for idx, processor_step in enumerate(self.steps):
# Apply before hooks # Apply before hooks
@@ -288,9 +294,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
# Apply after hooks # Apply after hooks
for hook in self.after_step_hooks: for hook in self.after_step_hooks:
hook(idx, transition) hook(idx, transition)
return transition
# Always use to_output for consistent typing
return self.to_output(transition)
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]: def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
"""Yield the intermediate results after each processor step. """Yield the intermediate results after each processor step.
@@ -763,6 +767,41 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
features = out features = out
return features return features
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
transition: EnvTransition = create_transition(observation=observation)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.OBSERVATION]
def process_action(self, action: Any | torch.Tensor) -> Any | torch.Tensor:
transition: EnvTransition = create_transition(action=action)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.ACTION]
def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
transition: EnvTransition = create_transition(reward=reward)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.REWARD]
def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
transition: EnvTransition = create_transition(done=done)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.DONE]
def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
transition: EnvTransition = create_transition(truncated=truncated)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.TRUNCATED]
def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
transition: EnvTransition = create_transition(info=info)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.INFO]
def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
transition: EnvTransition = create_transition(complementary_data=complementary_data)
transformed_transition = self._forward(transition)
return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
class ObservationProcessor(ProcessorStep, ABC): class ObservationProcessor(ProcessorStep, ABC):
"""Base class for processors that modify only the observation component of a transition. """Base class for processors that modify only the observation component of a transition.