From 356a64d8c423083240833a8b1ef23bfbba9fb43c Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Sun, 26 Apr 2026 16:36:21 +0200 Subject: [PATCH] fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline (cherry picked from commit 9c2af818ff4bfef2603348e0609aa249c3ff62b1) --- src/lerobot/processor/hil_processor.py | 33 +++++++++++++++----- src/lerobot/rl/gym_manipulator.py | 43 ++++++++++++++++---------- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 18859707c..c17441c46 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -321,6 +321,7 @@ class GymHILAdapterProcessorStep(ProcessorStep): This step normalizes the `transition` object by: 1. Copying `teleop_action` from `info` to `complementary_data`. 2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key). + 3. Copying `discrete_penalty` from `info` to `complementary_data`. """ def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -330,6 +331,9 @@ class GymHILAdapterProcessorStep(ProcessorStep): if TELEOP_ACTION_KEY in info: complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY] + if DISCRETE_PENALTY_KEY in info: + complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY] + if "is_intervention" in info: info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"] @@ -348,18 +352,24 @@ class GymHILAdapterProcessorStep(ProcessorStep): @ProcessorStepRegistry.register("gripper_penalty_processor") class GripperPenaltyProcessorStep(ProcessorStep): """ - Applies a penalty for inefficient gripper usage. + Applies a small per-transition cost on the discrete gripper action. - This step penalizes actions that attempt to close an already closed gripper or - open an already open one, based on position thresholds. + Fires only when the commanded action would actually transition the gripper + from one extreme to the other (close-while-open or open-while-closed). + This discourages gripper oscillation while leaving "stay" and saturating-further + commands unpenalized. Attributes: penalty: The negative reward value to apply. max_gripper_pos: The maximum position value for the gripper, used for normalization. + open_threshold: Normalized state below which the gripper is considered "open". + closed_threshold: Normalized state above which the gripper is considered "closed". """ - penalty: float = -0.01 + penalty: float = -0.02 max_gripper_pos: float = 30.0 + open_threshold: float = 0.1 + closed_threshold: float = 0.9 def __call__(self, transition: EnvTransition) -> EnvTransition: """ @@ -391,9 +401,13 @@ class GripperPenaltyProcessorStep(ProcessorStep): gripper_state_normalized = current_gripper_pos / self.max_gripper_pos # Calculate penalty boolean as in original - gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( - gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 - ) + # - currently open AND target is closed -> close transition + # - currently closed AND target is open -> open transition + is_open = gripper_state_normalized < self.open_threshold + is_closed = gripper_state_normalized > self.closed_threshold + cmd_close = gripper_action_normalized > self.closed_threshold + cmd_open = gripper_action_normalized < self.open_threshold + gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open) gripper_penalty = self.penalty * int(gripper_penalty_bool) @@ -409,11 +423,14 @@ class GripperPenaltyProcessorStep(ProcessorStep): Returns the configuration of the step for serialization. Returns: - A dictionary containing the penalty value and max gripper position. + A dictionary containing the penalty value, max gripper position, + and the open/closed thresholds. """ return { "penalty": self.penalty, "max_gripper_pos": self.max_gripper_pos, + "open_threshold": self.open_threshold, + "closed_threshold": self.closed_threshold, } def reset(self) -> None: diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 0d5bf1982..551ecf41e 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -551,6 +551,12 @@ def step_env_and_process_transition( terminated = terminated or processed_action_transition[TransitionKey.DONE] truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED] complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy() + + if hasattr(env, "get_raw_joint_positions"): + raw_joint_positions = env.get_raw_joint_positions() + if raw_joint_positions is not None: + complementary_data["raw_joint_positions"] = raw_joint_positions + new_info = info.copy() new_info.update(processed_action_transition[TransitionKey.INFO]) @@ -568,6 +574,24 @@ def step_env_and_process_transition( return new_transition +def reset_and_build_transition( + env: gym.Env, + env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + action_processor: DataProcessorPipeline[EnvTransition, EnvTransition], +) -> EnvTransition: + """Reset env + processors and return the first env-processed transition.""" + obs, info = env.reset() + env_processor.reset() + action_processor.reset() + complementary_data: dict[str, Any] = {} + if hasattr(env, "get_raw_joint_positions"): + raw_joint_positions = env.get_raw_joint_positions() + if raw_joint_positions is not None: + complementary_data["raw_joint_positions"] = raw_joint_positions + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + return env_processor(data=transition) + + def control_loop( env: gym.Env, env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], @@ -593,17 +617,7 @@ def control_loop( print("- When not intervening, robot will stay still") print("- Press Ctrl+C to exit") - # Reset environment and processors - obs, info = env.reset() - complementary_data = ( - {"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {} - ) - env_processor.reset() - action_processor.reset() - - # Process initial observation - transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) - transition = env_processor(data=transition) + transition = reset_and_build_transition(env, env_processor, action_processor) # Determine if gripper is used use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True @@ -726,12 +740,7 @@ def control_loop( dataset.save_episode() # Reset for new episode - obs, info = env.reset() - env_processor.reset() - action_processor.reset() - - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + transition = reset_and_build_transition(env, env_processor, action_processor) # Maintain fps timing precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))