mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
refactor(processor): transform_features loop + EAFP (#1932)
This commit is contained in:
@@ -59,9 +59,11 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION]["delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for axis in ["x", "y", "z"]:
|
||||
features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
if self.use_gripper:
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
@@ -94,10 +96,10 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop("delta_x", 0.0)
|
||||
delta_y = action.pop("delta_y", 0.0)
|
||||
delta_z = action.pop("delta_z", 0.0)
|
||||
gripper = action.pop("gripper", 1.0) # Default to "stay" (1.0)
|
||||
delta_x = action.pop("delta_x")
|
||||
delta_y = action.pop("delta_y")
|
||||
delta_z = action.pop("delta_z")
|
||||
gripper = action.pop("gripper")
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
@@ -132,18 +134,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Transform features to match output format."""
|
||||
features[PipelineFeatureType.ACTION].pop("delta_x", None)
|
||||
features[PipelineFeatureType.ACTION].pop("delta_y", None)
|
||||
features[PipelineFeatureType.ACTION].pop("delta_z", None)
|
||||
features[PipelineFeatureType.ACTION].pop("gripper", None)
|
||||
for axis in ["x", "y", "z", "gripper"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
|
||||
|
||||
for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
|
||||
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
@@ -78,7 +78,7 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
raw = comp["raw_joint_positions"]
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
@@ -155,24 +155,23 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("enabled", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_x", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_y", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_z", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wx", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wy", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wz", None)
|
||||
features[PipelineFeatureType.ACTION].pop("gripper_vel", None)
|
||||
for feat in [
|
||||
"enabled",
|
||||
"target_x",
|
||||
"target_y",
|
||||
"target_z",
|
||||
"target_wx",
|
||||
"target_wy",
|
||||
"target_wz",
|
||||
"gripper_vel",
|
||||
]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"{feat}", None)
|
||||
|
||||
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]:
|
||||
features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.gripper_vel"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@@ -200,12 +199,12 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
x = action.get("ee.x")
|
||||
y = action.get("ee.y")
|
||||
z = action.get("ee.z")
|
||||
wx = action.get("ee.wx")
|
||||
wy = action.get("ee.wy")
|
||||
wz = action.get("ee.wz")
|
||||
x = action["ee.x"]
|
||||
y = action["ee.y"]
|
||||
z = action["ee.z"]
|
||||
wx = action["ee.wx"]
|
||||
wy = action["ee.wy"]
|
||||
wz = action["ee.wz"]
|
||||
# TODO(Steven): ee.gripper_vel does not need to be bounded
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
@@ -294,7 +293,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
)
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
raw = comp["raw_joint_positions"]
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
@@ -330,13 +329,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("ee.x", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.y", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.z", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.wx", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.wy", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.wz", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_pos", None)
|
||||
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
|
||||
|
||||
for name in self.motor_names:
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
@@ -373,7 +368,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
gripper_vel = action.pop("ee.gripper_vel")
|
||||
|
||||
@@ -382,7 +377,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
"raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint"
|
||||
)
|
||||
|
||||
curr_gripper_pos = complementary_data.get("raw_joint_positions").get("gripper")
|
||||
curr_gripper_pos = complementary_data["raw_joint_positions"]["gripper"]
|
||||
|
||||
# TODO(Michel,Adil): Fix this logic
|
||||
# if self.discrete_gripper:
|
||||
@@ -403,7 +398,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel")
|
||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None)
|
||||
features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
@@ -428,14 +423,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
motor_joint_values = [observation.get(f"{n}.pos") for n in self.motor_names]
|
||||
motor_joint_values = [observation[f"{n}.pos"] for n in self.motor_names]
|
||||
|
||||
q = np.array(motor_joint_values, dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
gripper_pos = observation.get("gripper.pos")
|
||||
gripper_pos = observation["gripper.pos"]
|
||||
|
||||
for n in self.motor_names:
|
||||
observation.pop(f"{n}.pos")
|
||||
|
||||
@@ -56,10 +56,10 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
||||
ValueError: If 'pos' or 'rot' keys are missing from the input action.
|
||||
"""
|
||||
# Pop them from the action
|
||||
enabled = bool(action.pop("phone.enabled", 0))
|
||||
pos = action.pop("phone.pos", None)
|
||||
rot = action.pop("phone.rot", None)
|
||||
inputs = action.pop("phone.raw_inputs", {})
|
||||
enabled = bool(action.pop("phone.enabled"))
|
||||
pos = action.pop("phone.pos")
|
||||
rot = action.pop("phone.rot")
|
||||
inputs = action.pop("phone.raw_inputs")
|
||||
|
||||
if pos is None or rot is None:
|
||||
raise ValueError("pos and rot must be present in action")
|
||||
@@ -90,19 +90,21 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("phone.enabled", None)
|
||||
features[PipelineFeatureType.ACTION].pop("phone.pos", None)
|
||||
features[PipelineFeatureType.ACTION].pop("phone.rot", None)
|
||||
features[PipelineFeatureType.ACTION].pop("phone.raw_inputs", None)
|
||||
for feat in ["enabled", "pos", "rot", "raw_inputs"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None)
|
||||
|
||||
for feat in [
|
||||
"enabled",
|
||||
"target_x",
|
||||
"target_y",
|
||||
"target_z",
|
||||
"target_wx",
|
||||
"target_wy",
|
||||
"target_wz",
|
||||
"gripper_vel",
|
||||
]:
|
||||
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["gripper_vel"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user