refactor(processor): transform_features loop + EAFP (#1932)

This commit is contained in:
Steven Palma
2025-09-14 16:07:32 +02:00
committed by GitHub
parent 50293bb17b
commit c69f23723e
3 changed files with 68 additions and 75 deletions
+16 -20
View File
@@ -59,9 +59,11 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION]["delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for axis in ["x", "y", "z"]:
features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
if self.use_gripper:
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
@@ -94,10 +96,10 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
def action(self, action: RobotAction) -> RobotAction:
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
delta_x = action.pop("delta_x", 0.0)
delta_y = action.pop("delta_y", 0.0)
delta_z = action.pop("delta_z", 0.0)
gripper = action.pop("gripper", 1.0) # Default to "stay" (1.0)
delta_x = action.pop("delta_x")
delta_y = action.pop("delta_y")
delta_z = action.pop("delta_z")
gripper = action.pop("gripper")
# Determine if the teleoperator is actively providing input
# Consider enabled if any significant movement delta is detected
@@ -132,18 +134,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Transform features to match output format."""
features[PipelineFeatureType.ACTION].pop("delta_x", None)
features[PipelineFeatureType.ACTION].pop("delta_y", None)
features[PipelineFeatureType.ACTION].pop("delta_z", None)
features[PipelineFeatureType.ACTION].pop("gripper", None)
for axis in ["x", "y", "z", "gripper"]:
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
@@ -78,7 +78,7 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
raw = comp["raw_joint_positions"]
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
@@ -155,24 +155,23 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("enabled", None)
features[PipelineFeatureType.ACTION].pop("target_x", None)
features[PipelineFeatureType.ACTION].pop("target_y", None)
features[PipelineFeatureType.ACTION].pop("target_z", None)
features[PipelineFeatureType.ACTION].pop("target_wx", None)
features[PipelineFeatureType.ACTION].pop("target_wy", None)
features[PipelineFeatureType.ACTION].pop("target_wz", None)
features[PipelineFeatureType.ACTION].pop("gripper_vel", None)
for feat in [
"enabled",
"target_x",
"target_y",
"target_z",
"target_wx",
"target_wy",
"target_wz",
"gripper_vel",
]:
features[PipelineFeatureType.ACTION].pop(f"{feat}", None)
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]:
features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["ee.gripper_vel"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features
@@ -200,12 +199,12 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, action: RobotAction) -> RobotAction:
x = action.get("ee.x")
y = action.get("ee.y")
z = action.get("ee.z")
wx = action.get("ee.wx")
wy = action.get("ee.wy")
wz = action.get("ee.wz")
x = action["ee.x"]
y = action["ee.y"]
z = action["ee.z"]
wx = action["ee.wx"]
wy = action["ee.wy"]
wz = action["ee.wz"]
# TODO(Steven): ee.gripper_vel does not need to be bounded
if None in (x, y, z, wx, wy, wz):
@@ -294,7 +293,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
)
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
raw = comp["raw_joint_positions"]
if raw is None:
raise ValueError(
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
@@ -330,13 +329,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("ee.x", None)
features[PipelineFeatureType.ACTION].pop("ee.y", None)
features[PipelineFeatureType.ACTION].pop("ee.z", None)
features[PipelineFeatureType.ACTION].pop("ee.wx", None)
features[PipelineFeatureType.ACTION].pop("ee.wy", None)
features[PipelineFeatureType.ACTION].pop("ee.wz", None)
features[PipelineFeatureType.ACTION].pop("ee.gripper_pos", None)
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
for name in self.motor_names:
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
@@ -373,7 +368,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
discrete_gripper: bool = False
def action(self, action: RobotAction) -> RobotAction:
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
gripper_vel = action.pop("ee.gripper_vel")
@@ -382,7 +377,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
"raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint"
)
curr_gripper_pos = complementary_data.get("raw_joint_positions").get("gripper")
curr_gripper_pos = complementary_data["raw_joint_positions"]["gripper"]
# TODO(Michel,Adil): Fix this logic
# if self.discrete_gripper:
@@ -403,7 +398,7 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel")
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None)
features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
@@ -428,14 +423,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
motor_names: list[str]
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
motor_joint_values = [observation.get(f"{n}.pos") for n in self.motor_names]
motor_joint_values = [observation[f"{n}.pos"] for n in self.motor_names]
q = np.array(motor_joint_values, dtype=float)
t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
gripper_pos = observation.get("gripper.pos")
gripper_pos = observation["gripper.pos"]
for n in self.motor_names:
observation.pop(f"{n}.pos")
@@ -56,10 +56,10 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
ValueError: If 'pos' or 'rot' keys are missing from the input action.
"""
# Pop them from the action
enabled = bool(action.pop("phone.enabled", 0))
pos = action.pop("phone.pos", None)
rot = action.pop("phone.rot", None)
inputs = action.pop("phone.raw_inputs", {})
enabled = bool(action.pop("phone.enabled"))
pos = action.pop("phone.pos")
rot = action.pop("phone.rot")
inputs = action.pop("phone.raw_inputs")
if pos is None or rot is None:
raise ValueError("pos and rot must be present in action")
@@ -90,19 +90,21 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
features[PipelineFeatureType.ACTION].pop("phone.enabled", None)
features[PipelineFeatureType.ACTION].pop("phone.pos", None)
features[PipelineFeatureType.ACTION].pop("phone.rot", None)
features[PipelineFeatureType.ACTION].pop("phone.raw_inputs", None)
for feat in ["enabled", "pos", "rot", "raw_inputs"]:
features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None)
for feat in [
"enabled",
"target_x",
"target_y",
"target_z",
"target_wx",
"target_wy",
"target_wz",
"gripper_vel",
]:
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[PipelineFeatureType.ACTION]["gripper_vel"] = PolicyFeature(
type=FeatureType.ACTION, shape=(1,)
)
return features