Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873)

* fix(RL) add missing config arguments

* respond to copilot review

* fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic.

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
Michel Aractingi
2026-02-02 22:14:03 +01:00
committed by GitHub
parent 9c24a09665
commit 14a15f90e7
3 changed files with 23 additions and 12 deletions
+1
View File
@@ -205,6 +205,7 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
+12 -10
View File
@@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
class GripperPenaltyProcessorStep(ProcessorStep):
"""
Applies a penalty for inefficient gripper usage.
@@ -329,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data: dict) -> dict:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
transition: The incoming environment transition.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
The modified transition with the penalty added to complementary data.
"""
action = self.transition.get(TransitionKey.ACTION)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return complementary_data
return new_transition
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
if current_gripper_pos is None:
return complementary_data
return new_transition
# Gripper action is a PolicyAction at this stage
gripper_action = action[-1].item()
@@ -364,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
gripper_penalty = self.penalty * int(gripper_penalty_bool)
# Create new complementary data with penalty info
# Update complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""
+10 -2
View File
@@ -412,7 +412,10 @@ def make_processors(
if cfg.processor.observation.add_current_to_observation:
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
if kinematics_solver is not None:
add_ee_pose = (
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
)
if kinematics_solver is not None and add_ee_pose:
env_pipeline_steps.append(
ForwardKinematicsJointsToEEObservation(
kinematics=kinematics_solver,
@@ -435,7 +438,12 @@ def make_processors(
)
# Add gripper penalty processor if gripper config exists and enabled
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
# Only add if max_gripper_pos is explicitly configured (required for normalization)
if (
cfg.processor.gripper is not None
and cfg.processor.gripper.use_gripper
and cfg.processor.max_gripper_pos is not None
):
env_pipeline_steps.append(
GripperPenaltyProcessorStep(
penalty=cfg.processor.gripper.gripper_penalty,