refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844)

This commit is contained in:
Steven Palma
2025-09-02 17:57:49 +02:00
committed by GitHub
parent 2914ae2a96
commit ebb464c255
2 changed files with 23 additions and 21 deletions
@@ -251,8 +251,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
initial_guess_current_joints: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
new_transition = transition.copy()
act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None)
@@ -262,7 +263,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return transition
return new_transition
# Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None)
@@ -292,10 +293,10 @@ class InverseKinematicsEEToJoints(ProcessorStep):
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
else:
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act
new_transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
@@ -333,18 +334,19 @@ class GripperVelocityToJoint(ProcessorStep):
discrete_gripper: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION) or {}
act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if f"{ACTION}.gripper" not in act:
return transition
return new_transition
if "gripper" not in self.motor_names:
new_act = dict(act)
new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act
return transition
new_transition[TransitionKey.ACTION] = new_act
return new_transition
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
@@ -367,11 +369,11 @@ class GripperVelocityToJoint(ProcessorStep):
new_act = dict(act)
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act
new_transition[TransitionKey.ACTION] = new_act
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
transition[TransitionKey.OBSERVATION] = obs
return transition
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.gripper", None)