mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
refactor(processor): transform_features loop + EAFP (#1932)
This commit is contained in:
@@ -59,9 +59,11 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
|||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
features[PipelineFeatureType.ACTION]["delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
for axis in ["x", "y", "z"]:
|
||||||
features[PipelineFeatureType.ACTION]["delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
|
||||||
features[PipelineFeatureType.ACTION]["delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
|
)
|
||||||
|
|
||||||
if self.use_gripper:
|
if self.use_gripper:
|
||||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
|
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
|
||||||
type=FeatureType.ACTION, shape=(1,)
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
@@ -94,10 +96,10 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
|||||||
def action(self, action: RobotAction) -> RobotAction:
|
def action(self, action: RobotAction) -> RobotAction:
|
||||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||||
delta_x = action.pop("delta_x", 0.0)
|
delta_x = action.pop("delta_x")
|
||||||
delta_y = action.pop("delta_y", 0.0)
|
delta_y = action.pop("delta_y")
|
||||||
delta_z = action.pop("delta_z", 0.0)
|
delta_z = action.pop("delta_z")
|
||||||
gripper = action.pop("gripper", 1.0) # Default to "stay" (1.0)
|
gripper = action.pop("gripper")
|
||||||
|
|
||||||
# Determine if the teleoperator is actively providing input
|
# Determine if the teleoperator is actively providing input
|
||||||
# Consider enabled if any significant movement delta is detected
|
# Consider enabled if any significant movement delta is detected
|
||||||
@@ -132,18 +134,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
|||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
"""Transform features to match output format."""
|
for axis in ["x", "y", "z", "gripper"]:
|
||||||
features[PipelineFeatureType.ACTION].pop("delta_x", None)
|
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
|
||||||
features[PipelineFeatureType.ACTION].pop("delta_y", None)
|
|
||||||
features[PipelineFeatureType.ACTION].pop("delta_z", None)
|
for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
|
||||||
features[PipelineFeatureType.ACTION].pop("gripper", None)
|
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
|
)
|
||||||
|
|
||||||
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
|||||||
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
|
||||||
# Get joint positions from complimentary data
|
# Get joint positions from complimentary data
|
||||||
raw = comp.get("raw_joint_positions", None)
|
raw = comp["raw_joint_positions"]
|
||||||
if raw is None:
|
if raw is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||||
@@ -155,24 +155,23 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
|||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
features[PipelineFeatureType.ACTION].pop("enabled", None)
|
for feat in [
|
||||||
features[PipelineFeatureType.ACTION].pop("target_x", None)
|
"enabled",
|
||||||
features[PipelineFeatureType.ACTION].pop("target_y", None)
|
"target_x",
|
||||||
features[PipelineFeatureType.ACTION].pop("target_z", None)
|
"target_y",
|
||||||
features[PipelineFeatureType.ACTION].pop("target_wx", None)
|
"target_z",
|
||||||
features[PipelineFeatureType.ACTION].pop("target_wy", None)
|
"target_wx",
|
||||||
features[PipelineFeatureType.ACTION].pop("target_wz", None)
|
"target_wy",
|
||||||
features[PipelineFeatureType.ACTION].pop("gripper_vel", None)
|
"target_wz",
|
||||||
|
"gripper_vel",
|
||||||
|
]:
|
||||||
|
features[PipelineFeatureType.ACTION].pop(f"{feat}", None)
|
||||||
|
|
||||||
|
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]:
|
||||||
|
features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
|
)
|
||||||
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["ee.gripper_vel"] = PolicyFeature(
|
|
||||||
type=FeatureType.ACTION, shape=(1,)
|
|
||||||
)
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@@ -200,12 +199,12 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
|
|||||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
def action(self, action: RobotAction) -> RobotAction:
|
def action(self, action: RobotAction) -> RobotAction:
|
||||||
x = action.get("ee.x")
|
x = action["ee.x"]
|
||||||
y = action.get("ee.y")
|
y = action["ee.y"]
|
||||||
z = action.get("ee.z")
|
z = action["ee.z"]
|
||||||
wx = action.get("ee.wx")
|
wx = action["ee.wx"]
|
||||||
wy = action.get("ee.wy")
|
wy = action["ee.wy"]
|
||||||
wz = action.get("ee.wz")
|
wz = action["ee.wz"]
|
||||||
# TODO(Steven): ee.gripper_vel does not need to be bounded
|
# TODO(Steven): ee.gripper_vel does not need to be bounded
|
||||||
|
|
||||||
if None in (x, y, z, wx, wy, wz):
|
if None in (x, y, z, wx, wy, wz):
|
||||||
@@ -294,7 +293,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get joint positions from complimentary data
|
# Get joint positions from complimentary data
|
||||||
raw = comp.get("raw_joint_positions", None)
|
raw = comp["raw_joint_positions"]
|
||||||
if raw is None:
|
if raw is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||||
@@ -330,13 +329,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.x", None)
|
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.y", None)
|
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.z", None)
|
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.wx", None)
|
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.wy", None)
|
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.wz", None)
|
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_pos", None)
|
|
||||||
for name in self.motor_names:
|
for name in self.motor_names:
|
||||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||||
type=FeatureType.ACTION, shape=(1,)
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
@@ -373,7 +368,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
|||||||
discrete_gripper: bool = False
|
discrete_gripper: bool = False
|
||||||
|
|
||||||
def action(self, action: RobotAction) -> RobotAction:
|
def action(self, action: RobotAction) -> RobotAction:
|
||||||
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
|
||||||
gripper_vel = action.pop("ee.gripper_vel")
|
gripper_vel = action.pop("ee.gripper_vel")
|
||||||
|
|
||||||
@@ -382,7 +377,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
|||||||
"raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint"
|
"raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint"
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_gripper_pos = complementary_data.get("raw_joint_positions").get("gripper")
|
curr_gripper_pos = complementary_data["raw_joint_positions"]["gripper"]
|
||||||
|
|
||||||
# TODO(Michel,Adil): Fix this logic
|
# TODO(Michel,Adil): Fix this logic
|
||||||
# if self.discrete_gripper:
|
# if self.discrete_gripper:
|
||||||
@@ -403,7 +398,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
|||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel")
|
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None)
|
||||||
features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
|
features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
|
||||||
type=FeatureType.ACTION, shape=(1,)
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
)
|
)
|
||||||
@@ -428,14 +423,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
|||||||
motor_names: list[str]
|
motor_names: list[str]
|
||||||
|
|
||||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||||
motor_joint_values = [observation.get(f"{n}.pos") for n in self.motor_names]
|
motor_joint_values = [observation[f"{n}.pos"] for n in self.motor_names]
|
||||||
|
|
||||||
q = np.array(motor_joint_values, dtype=float)
|
q = np.array(motor_joint_values, dtype=float)
|
||||||
t = self.kinematics.forward_kinematics(q)
|
t = self.kinematics.forward_kinematics(q)
|
||||||
pos = t[:3, 3]
|
pos = t[:3, 3]
|
||||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||||
|
|
||||||
gripper_pos = observation.get("gripper.pos")
|
gripper_pos = observation["gripper.pos"]
|
||||||
|
|
||||||
for n in self.motor_names:
|
for n in self.motor_names:
|
||||||
observation.pop(f"{n}.pos")
|
observation.pop(f"{n}.pos")
|
||||||
|
|||||||
@@ -56,10 +56,10 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
|||||||
ValueError: If 'pos' or 'rot' keys are missing from the input action.
|
ValueError: If 'pos' or 'rot' keys are missing from the input action.
|
||||||
"""
|
"""
|
||||||
# Pop them from the action
|
# Pop them from the action
|
||||||
enabled = bool(action.pop("phone.enabled", 0))
|
enabled = bool(action.pop("phone.enabled"))
|
||||||
pos = action.pop("phone.pos", None)
|
pos = action.pop("phone.pos")
|
||||||
rot = action.pop("phone.rot", None)
|
rot = action.pop("phone.rot")
|
||||||
inputs = action.pop("phone.raw_inputs", {})
|
inputs = action.pop("phone.raw_inputs")
|
||||||
|
|
||||||
if pos is None or rot is None:
|
if pos is None or rot is None:
|
||||||
raise ValueError("pos and rot must be present in action")
|
raise ValueError("pos and rot must be present in action")
|
||||||
@@ -90,19 +90,21 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
|||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
features[PipelineFeatureType.ACTION].pop("phone.enabled", None)
|
for feat in ["enabled", "pos", "rot", "raw_inputs"]:
|
||||||
features[PipelineFeatureType.ACTION].pop("phone.pos", None)
|
features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None)
|
||||||
features[PipelineFeatureType.ACTION].pop("phone.rot", None)
|
|
||||||
features[PipelineFeatureType.ACTION].pop("phone.raw_inputs", None)
|
for feat in [
|
||||||
|
"enabled",
|
||||||
|
"target_x",
|
||||||
|
"target_y",
|
||||||
|
"target_z",
|
||||||
|
"target_wx",
|
||||||
|
"target_wy",
|
||||||
|
"target_wz",
|
||||||
|
"gripper_vel",
|
||||||
|
]:
|
||||||
|
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION, shape=(1,)
|
||||||
|
)
|
||||||
|
|
||||||
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
|
||||||
features[PipelineFeatureType.ACTION]["gripper_vel"] = PolicyFeature(
|
|
||||||
type=FeatureType.ACTION, shape=(1,)
|
|
||||||
)
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
Reference in New Issue
Block a user