mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
refactor(processors): add extended api for specialized pipelines (#1848)
This commit is contained in:
@@ -65,7 +65,7 @@ kinematics_solver = RobotKinematics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert ee pose action to joint action
|
# Build pipeline to convert ee pose action to joint action
|
||||||
robot_ee_to_joints = RobotProcessor(
|
robot_ee_to_joints_processor = RobotProcessor(
|
||||||
steps=[
|
steps=[
|
||||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||||
InverseKinematicsEEToJoints(
|
InverseKinematicsEEToJoints(
|
||||||
@@ -79,7 +79,7 @@ robot_ee_to_joints = RobotProcessor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert joint observation to ee pose observation
|
# Build pipeline to convert joint observation to ee pose observation
|
||||||
robot_joints_to_ee_pose = RobotProcessor(
|
robot_joints_to_ee_pose_processor = RobotProcessor(
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||||
],
|
],
|
||||||
@@ -89,7 +89,7 @@ robot_joints_to_ee_pose = RobotProcessor(
|
|||||||
|
|
||||||
# Build dataset action and gripper features
|
# Build dataset action and gripper features
|
||||||
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
||||||
pipeline=robot_ee_to_joints,
|
pipeline=robot_ee_to_joints_processor,
|
||||||
initial_features={},
|
initial_features={},
|
||||||
use_videos=True,
|
use_videos=True,
|
||||||
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
||||||
@@ -97,7 +97,7 @@ action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
|||||||
|
|
||||||
# Build dataset observation features
|
# Build dataset observation features
|
||||||
obs_ee = aggregate_pipeline_dataset_features(
|
obs_ee = aggregate_pipeline_dataset_features(
|
||||||
pipeline=robot_joints_to_ee_pose,
|
pipeline=robot_joints_to_ee_pose_processor,
|
||||||
initial_features=robot.observation_features,
|
initial_features=robot.observation_features,
|
||||||
use_videos=True,
|
use_videos=True,
|
||||||
patterns=["observation.state.ee"],
|
patterns=["observation.state.ee"],
|
||||||
@@ -147,8 +147,8 @@ for episode_idx in range(NUM_EPISODES):
|
|||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
robot_action_processor=robot_ee_to_joints,
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||||
)
|
)
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ kinematics_solver = RobotKinematics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert phone action to ee pose action
|
# Build pipeline to convert phone action to ee pose action
|
||||||
phone_to_robot_ee_pose = RobotProcessor(
|
phone_to_robot_ee_pose_processor = RobotProcessor(
|
||||||
steps=[
|
steps=[
|
||||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||||
@@ -93,7 +93,7 @@ phone_to_robot_ee_pose = RobotProcessor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert ee pose action to joint action
|
# Build pipeline to convert ee pose action to joint action
|
||||||
robot_ee_to_joints = RobotProcessor(
|
robot_ee_to_joints_processor = RobotProcessor(
|
||||||
steps=[
|
steps=[
|
||||||
InverseKinematicsEEToJoints(
|
InverseKinematicsEEToJoints(
|
||||||
kinematics=kinematics_solver,
|
kinematics=kinematics_solver,
|
||||||
@@ -120,7 +120,7 @@ robot_joints_to_ee_pose = RobotProcessor(
|
|||||||
|
|
||||||
# Build dataset ee action features
|
# Build dataset ee action features
|
||||||
action_ee = aggregate_pipeline_dataset_features(
|
action_ee = aggregate_pipeline_dataset_features(
|
||||||
pipeline=phone_to_robot_ee_pose,
|
pipeline=phone_to_robot_ee_pose_processor,
|
||||||
initial_features=phone.action_features,
|
initial_features=phone.action_features,
|
||||||
use_videos=True,
|
use_videos=True,
|
||||||
patterns=["action.ee"],
|
patterns=["action.ee"],
|
||||||
@@ -128,7 +128,7 @@ action_ee = aggregate_pipeline_dataset_features(
|
|||||||
|
|
||||||
# Get gripper pos action features
|
# Get gripper pos action features
|
||||||
gripper = aggregate_pipeline_dataset_features(
|
gripper = aggregate_pipeline_dataset_features(
|
||||||
pipeline=robot_ee_to_joints,
|
pipeline=robot_ee_to_joints_processor,
|
||||||
initial_features={},
|
initial_features={},
|
||||||
use_videos=True,
|
use_videos=True,
|
||||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||||
@@ -177,8 +177,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=phone_to_robot_ee_pose,
|
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||||
robot_action_processor=robot_ee_to_joints,
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
robot_observation_processor=robot_joints_to_ee_pose,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -193,8 +193,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=phone_to_robot_ee_pose,
|
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||||
robot_action_processor=robot_ee_to_joints,
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
robot_observation_processor=robot_joints_to_ee_pose,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ kinematics_solver = RobotKinematics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert ee pose action to joint action
|
# Build pipeline to convert ee pose action to joint action
|
||||||
robot_ee_to_joints = RobotProcessor(
|
robot_ee_to_joints_processor = RobotProcessor(
|
||||||
steps=[
|
steps=[
|
||||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||||
InverseKinematicsEEToJoints(
|
InverseKinematicsEEToJoints(
|
||||||
@@ -63,7 +63,7 @@ robot_ee_to_joints = RobotProcessor(
|
|||||||
to_output=to_output_robot_action,
|
to_output=to_output_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
robot_ee_to_joints.reset()
|
robot_ee_to_joints_processor.reset()
|
||||||
|
|
||||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||||
for idx in range(dataset.num_frames):
|
for idx in range(dataset.num_frames):
|
||||||
@@ -73,7 +73,7 @@ for idx in range(dataset.num_frames):
|
|||||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||||
}
|
}
|
||||||
|
|
||||||
joint_action = robot_ee_to_joints(ee_action)
|
joint_action = robot_ee_to_joints_processor(ee_action)
|
||||||
action_sent = robot.send_action(joint_action)
|
action_sent = robot.send_action(joint_action)
|
||||||
|
|
||||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ kinematics_solver = RobotKinematics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert phone action to ee pose action to joint action
|
# Build pipeline to convert phone action to ee pose action to joint action
|
||||||
phone_to_robot_joints = RobotProcessor(
|
phone_to_robot_joints_processor = RobotProcessor(
|
||||||
steps=[
|
steps=[
|
||||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||||
@@ -85,7 +85,7 @@ while True:
|
|||||||
phone_obs = teleop_device.get_action()
|
phone_obs = teleop_device.get_action()
|
||||||
|
|
||||||
# Phone -> EE pose -> Joints transition
|
# Phone -> EE pose -> Joints transition
|
||||||
joint_action = phone_to_robot_joints(phone_obs)
|
joint_action = phone_to_robot_joints_processor(phone_obs)
|
||||||
|
|
||||||
if joint_action:
|
if joint_action:
|
||||||
robot.send_action(joint_action)
|
robot.send_action(joint_action)
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from safetensors.torch import load_file, save_file
|
|||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
|
||||||
from .converters import batch_to_transition, transition_to_batch
|
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||||
from .core import EnvTransition, TransitionKey
|
from .core import EnvTransition, TransitionKey
|
||||||
|
|
||||||
# Type variable for generic processor output type
|
# Type variable for generic processor output type
|
||||||
@@ -276,6 +276,12 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
|||||||
# Always convert input through to_transition
|
# Always convert input through to_transition
|
||||||
transition = self.to_transition(data)
|
transition = self.to_transition(data)
|
||||||
|
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
|
||||||
|
# Always use to_output for consistent typing
|
||||||
|
return self.to_output(transformed_transition)
|
||||||
|
|
||||||
|
def _forward(self, transition: EnvTransition) -> EnvTransition:
|
||||||
# Process through all steps
|
# Process through all steps
|
||||||
for idx, processor_step in enumerate(self.steps):
|
for idx, processor_step in enumerate(self.steps):
|
||||||
# Apply before hooks
|
# Apply before hooks
|
||||||
@@ -288,9 +294,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
|||||||
# Apply after hooks
|
# Apply after hooks
|
||||||
for hook in self.after_step_hooks:
|
for hook in self.after_step_hooks:
|
||||||
hook(idx, transition)
|
hook(idx, transition)
|
||||||
|
return transition
|
||||||
# Always use to_output for consistent typing
|
|
||||||
return self.to_output(transition)
|
|
||||||
|
|
||||||
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
|
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
|
||||||
"""Yield the intermediate results after each processor step.
|
"""Yield the intermediate results after each processor step.
|
||||||
@@ -763,6 +767,41 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
|||||||
features = out
|
features = out
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
transition: EnvTransition = create_transition(observation=observation)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.OBSERVATION]
|
||||||
|
|
||||||
|
def process_action(self, action: Any | torch.Tensor) -> Any | torch.Tensor:
|
||||||
|
transition: EnvTransition = create_transition(action=action)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.ACTION]
|
||||||
|
|
||||||
|
def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
|
||||||
|
transition: EnvTransition = create_transition(reward=reward)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.REWARD]
|
||||||
|
|
||||||
|
def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
|
||||||
|
transition: EnvTransition = create_transition(done=done)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.DONE]
|
||||||
|
|
||||||
|
def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
|
||||||
|
transition: EnvTransition = create_transition(truncated=truncated)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.TRUNCATED]
|
||||||
|
|
||||||
|
def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
transition: EnvTransition = create_transition(info=info)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.INFO]
|
||||||
|
|
||||||
|
def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
transition: EnvTransition = create_transition(complementary_data=complementary_data)
|
||||||
|
transformed_transition = self._forward(transition)
|
||||||
|
return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
|
||||||
|
|
||||||
class ObservationProcessor(ProcessorStep, ABC):
|
class ObservationProcessor(ProcessorStep, ABC):
|
||||||
"""Base class for processors that modify only the observation component of a transition.
|
"""Base class for processors that modify only the observation component of a transition.
|
||||||
|
|||||||
Reference in New Issue
Block a user