refactor(processor): transform_features loop + EAFP (#1932)

This commit is contained in:
Steven Palma
2025-09-14 16:07:32 +02:00
committed by GitHub
parent 50293bb17b
commit c69f23723e
3 changed files with 68 additions and 75 deletions
+16 -20
View File
@@ -59,9 +59,11 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION]["delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) for axis in ["x", "y", "z"]:
features[PipelineFeatureType.ACTION]["delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
features[PipelineFeatureType.ACTION]["delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,)) type=FeatureType.ACTION, shape=(1,)
)
if self.use_gripper: if self.use_gripper:
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature( features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,) type=FeatureType.ACTION, shape=(1,)
@@ -94,10 +96,10 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
def action(self, action: RobotAction) -> RobotAction: def action(self, action: RobotAction) -> RobotAction:
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
delta_x = action.pop("delta_x", 0.0) delta_x = action.pop("delta_x")
delta_y = action.pop("delta_y", 0.0) delta_y = action.pop("delta_y")
delta_z = action.pop("delta_z", 0.0) delta_z = action.pop("delta_z")
gripper = action.pop("gripper", 1.0) # Default to "stay" (1.0) gripper = action.pop("gripper")
# Determine if the teleoperator is actively providing input # Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected # Consider enabled if any significant movement delta is detected
@@ -132,18 +134,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Transform features to match output format.""" for axis in ["x", "y", "z", "gripper"]:
features[PipelineFeatureType.ACTION].pop("delta_x", None) features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
features[PipelineFeatureType.ACTION].pop("delta_y", None)
features[PipelineFeatureType.ACTION].pop("delta_z", None) for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
features[PipelineFeatureType.ACTION].pop("gripper", None) features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features return features
@@ -78,7 +78,7 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
# Get joint positions from complimentary data # Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None) raw = comp["raw_joint_positions"]
if raw is None: if raw is None:
raise ValueError( raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta" "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
@@ -155,24 +155,23 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("enabled", None) for feat in [
features[PipelineFeatureType.ACTION].pop("target_x", None) "enabled",
features[PipelineFeatureType.ACTION].pop("target_y", None) "target_x",
features[PipelineFeatureType.ACTION].pop("target_z", None) "target_y",
features[PipelineFeatureType.ACTION].pop("target_wx", None) "target_z",
features[PipelineFeatureType.ACTION].pop("target_wy", None) "target_wx",
features[PipelineFeatureType.ACTION].pop("target_wz", None) "target_wy",
features[PipelineFeatureType.ACTION].pop("gripper_vel", None) "target_wz",
"gripper_vel",
]:
features[PipelineFeatureType.ACTION].pop(f"{feat}", None)
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]:
features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.gripper_vel"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features return features
@@ -200,12 +199,12 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False) _last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, action: RobotAction) -> RobotAction: def action(self, action: RobotAction) -> RobotAction:
x = action.get("ee.x") x = action["ee.x"]
y = action.get("ee.y") y = action["ee.y"]
z = action.get("ee.z") z = action["ee.z"]
wx = action.get("ee.wx") wx = action["ee.wx"]
wy = action.get("ee.wy") wy = action["ee.wy"]
wz = action.get("ee.wz") wz = action["ee.wz"]
# TODO(Steven): ee.gripper_vel does not need to be bounded # TODO(Steven): ee.gripper_vel does not need to be bounded
if None in (x, y, z, wx, wy, wz): if None in (x, y, z, wx, wy, wz):
@@ -294,7 +293,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
) )
# Get joint positions from complimentary data # Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None) raw = comp["raw_joint_positions"]
if raw is None: if raw is None:
raise ValueError( raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta" "raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
@@ -330,13 +329,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("ee.x", None) for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
features[PipelineFeatureType.ACTION].pop("ee.y", None) features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
features[PipelineFeatureType.ACTION].pop("ee.z", None)
features[PipelineFeatureType.ACTION].pop("ee.wx", None)
features[PipelineFeatureType.ACTION].pop("ee.wy", None)
features[PipelineFeatureType.ACTION].pop("ee.wz", None)
features[PipelineFeatureType.ACTION].pop("ee.gripper_pos", None)
for name in self.motor_names: for name in self.motor_names:
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,) type=FeatureType.ACTION, shape=(1,)
@@ -373,7 +368,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
discrete_gripper: bool = False discrete_gripper: bool = False
def action(self, action: RobotAction) -> RobotAction: def action(self, action: RobotAction) -> RobotAction:
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
gripper_vel = action.pop("ee.gripper_vel") gripper_vel = action.pop("ee.gripper_vel")
@@ -382,7 +377,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
"raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint" "raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint"
) )
curr_gripper_pos = complementary_data.get("raw_joint_positions").get("gripper") curr_gripper_pos = complementary_data["raw_joint_positions"]["gripper"]
# TODO(Michel,Adil): Fix this logic # TODO(Michel,Adil): Fix this logic
# if self.discrete_gripper: # if self.discrete_gripper:
@@ -403,7 +398,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel") features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None)
features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature( features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,) type=FeatureType.ACTION, shape=(1,)
) )
@@ -428,14 +423,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
motor_names: list[str] motor_names: list[str]
def observation(self, observation: dict[str, Any]) -> dict[str, Any]: def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
motor_joint_values = [observation.get(f"{n}.pos") for n in self.motor_names] motor_joint_values = [observation[f"{n}.pos"] for n in self.motor_names]
q = np.array(motor_joint_values, dtype=float) q = np.array(motor_joint_values, dtype=float)
t = self.kinematics.forward_kinematics(q) t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3] pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
gripper_pos = observation.get("gripper.pos") gripper_pos = observation["gripper.pos"]
for n in self.motor_names: for n in self.motor_names:
observation.pop(f"{n}.pos") observation.pop(f"{n}.pos")
@@ -56,10 +56,10 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
ValueError: If 'pos' or 'rot' keys are missing from the input action. ValueError: If 'pos' or 'rot' keys are missing from the input action.
""" """
# Pop them from the action # Pop them from the action
enabled = bool(action.pop("phone.enabled", 0)) enabled = bool(action.pop("phone.enabled"))
pos = action.pop("phone.pos", None) pos = action.pop("phone.pos")
rot = action.pop("phone.rot", None) rot = action.pop("phone.rot")
inputs = action.pop("phone.raw_inputs", {}) inputs = action.pop("phone.raw_inputs")
if pos is None or rot is None: if pos is None or rot is None:
raise ValueError("pos and rot must be present in action") raise ValueError("pos and rot must be present in action")
@@ -90,19 +90,21 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
def transform_features( def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("phone.enabled", None) for feat in ["enabled", "pos", "rot", "raw_inputs"]:
features[PipelineFeatureType.ACTION].pop("phone.pos", None) features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None)
features[PipelineFeatureType.ACTION].pop("phone.rot", None)
features[PipelineFeatureType.ACTION].pop("phone.raw_inputs", None) for feat in [
"enabled",
"target_x",
"target_y",
"target_z",
"target_wx",
"target_wy",
"target_wz",
"gripper_vel",
]:
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["gripper_vel"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features return features