refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844)

This commit is contained in:
Steven Palma
2025-09-02 17:57:49 +02:00
committed by GitHub
parent 2914ae2a96
commit ebb464c255
2 changed files with 23 additions and 21 deletions
+6 -6
View File
@@ -282,15 +282,16 @@ class RewardClassifierProcessor(ProcessorStep):
self.reward_classifier.eval() self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION) new_transition = transition.copy()
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None: if observation is None or self.reward_classifier is None:
return transition return new_transition
# Extract images from observation # Extract images from observation
images = {key: value for key, value in observation.items() if "image" in key} images = {key: value for key, value in observation.items() if "image" in key}
if not images: if not images:
return transition return new_transition
# Run reward classifier # Run reward classifier
start_time = time.perf_counter() start_time = time.perf_counter()
@@ -300,8 +301,8 @@ class RewardClassifierProcessor(ProcessorStep):
classifier_frequency = 1 / (time.perf_counter() - start_time) classifier_frequency = 1 / (time.perf_counter() - start_time)
# Calculate reward and termination # Calculate reward and termination
reward = transition.get(TransitionKey.REWARD, 0.0) reward = new_transition.get(TransitionKey.REWARD, 0.0)
terminated = transition.get(TransitionKey.DONE, False) terminated = new_transition.get(TransitionKey.DONE, False)
if math.isclose(success, 1, abs_tol=1e-2): if math.isclose(success, 1, abs_tol=1e-2):
reward = self.success_reward reward = self.success_reward
@@ -309,7 +310,6 @@ class RewardClassifierProcessor(ProcessorStep):
terminated = True terminated = True
# Update transition # Update transition
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward new_transition[TransitionKey.REWARD] = reward
new_transition[TransitionKey.DONE] = terminated new_transition[TransitionKey.DONE] = terminated
@@ -251,8 +251,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
initial_guess_current_joints: bool = True initial_guess_current_joints: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
act = transition.get(TransitionKey.ACTION) or {} new_transition = transition.copy()
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get(f"{ACTION}.ee.x", None) x = act.get(f"{ACTION}.ee.x", None)
y = act.get(f"{ACTION}.ee.y", None) y = act.get(f"{ACTION}.ee.y", None)
@@ -262,7 +263,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
wz = act.get(f"{ACTION}.ee.wz", None) wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz): if None in (x, y, z, wx, wy, wz):
return transition return new_transition
# Get joint positions from complimentary data # Get joint positions from complimentary data
raw = comp.get("raw_joint_positions", None) raw = comp.get("raw_joint_positions", None)
@@ -292,10 +293,10 @@ class InverseKinematicsEEToJoints(ProcessorStep):
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"]) new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
else: else:
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i]) new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act new_transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints: if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
@@ -333,18 +334,19 @@ class GripperVelocityToJoint(ProcessorStep):
discrete_gripper: bool = False discrete_gripper: bool = False
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION) or {} new_transition = transition.copy()
act = transition.get(TransitionKey.ACTION) or {} obs = new_transition.get(TransitionKey.OBSERVATION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} act = new_transition.get(TransitionKey.ACTION) or {}
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if f"{ACTION}.gripper" not in act: if f"{ACTION}.gripper" not in act:
return transition return new_transition
if "gripper" not in self.motor_names: if "gripper" not in self.motor_names:
new_act = dict(act) new_act = dict(act)
new_act.pop(f"{ACTION}.gripper", None) new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act new_transition[TransitionKey.ACTION] = new_act
return transition return new_transition
if self.discrete_gripper: if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2] # Discrete gripper actions are in [0, 1, 2]
@@ -367,11 +369,11 @@ class GripperVelocityToJoint(ProcessorStep):
new_act = dict(act) new_act = dict(act)
new_act[f"{ACTION}.gripper.pos"] = gripper_pos new_act[f"{ACTION}.gripper.pos"] = gripper_pos
new_act.pop(f"{ACTION}.gripper", None) new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act new_transition[TransitionKey.ACTION] = new_act
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
transition[TransitionKey.OBSERVATION] = obs new_transition[TransitionKey.OBSERVATION] = obs
return transition return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.gripper", None) features.pop(f"{ACTION}.gripper", None)