mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline
(cherry picked from commit 9c2af818ff)
This commit is contained in:
@@ -321,6 +321,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
|||||||
This step normalizes the `transition` object by:
|
This step normalizes the `transition` object by:
|
||||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||||
|
3. Copying `discrete_penalty` from `info` to `complementary_data`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -330,6 +331,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
|||||||
if TELEOP_ACTION_KEY in info:
|
if TELEOP_ACTION_KEY in info:
|
||||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||||
|
|
||||||
|
if DISCRETE_PENALTY_KEY in info:
|
||||||
|
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
|
||||||
|
|
||||||
if "is_intervention" in info:
|
if "is_intervention" in info:
|
||||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||||
|
|
||||||
@@ -348,18 +352,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
|||||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||||
"""
|
"""
|
||||||
Applies a penalty for inefficient gripper usage.
|
Applies a small per-transition cost on the discrete gripper action.
|
||||||
|
|
||||||
This step penalizes actions that attempt to close an already closed gripper or
|
Fires only when the commanded action would actually transition the gripper
|
||||||
open an already open one, based on position thresholds.
|
from one extreme to the other (close-while-open or open-while-closed).
|
||||||
|
This discourages gripper oscillation while leaving "stay" and saturating-further
|
||||||
|
commands unpenalized.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
penalty: The negative reward value to apply.
|
penalty: The negative reward value to apply.
|
||||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||||
|
open_threshold: Normalized state below which the gripper is considered "open".
|
||||||
|
closed_threshold: Normalized state above which the gripper is considered "closed".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
penalty: float = -0.01
|
penalty: float = -0.02
|
||||||
max_gripper_pos: float = 30.0
|
max_gripper_pos: float = 30.0
|
||||||
|
open_threshold: float = 0.1
|
||||||
|
closed_threshold: float = 0.9
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""
|
"""
|
||||||
@@ -391,9 +401,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
|||||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||||
|
|
||||||
# Calculate penalty boolean as in original
|
# Calculate penalty boolean as in original
|
||||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
# - currently open AND target is closed -> close transition
|
||||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
# - currently closed AND target is open -> open transition
|
||||||
)
|
is_open = gripper_state_normalized < self.open_threshold
|
||||||
|
is_closed = gripper_state_normalized > self.closed_threshold
|
||||||
|
cmd_close = gripper_action_normalized > self.closed_threshold
|
||||||
|
cmd_open = gripper_action_normalized < self.open_threshold
|
||||||
|
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
|
||||||
|
|
||||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||||
|
|
||||||
@@ -409,11 +423,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
|||||||
Returns the configuration of the step for serialization.
|
Returns the configuration of the step for serialization.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the penalty value and max gripper position.
|
A dictionary containing the penalty value, max gripper position,
|
||||||
|
and the open/closed thresholds.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"penalty": self.penalty,
|
"penalty": self.penalty,
|
||||||
"max_gripper_pos": self.max_gripper_pos,
|
"max_gripper_pos": self.max_gripper_pos,
|
||||||
|
"open_threshold": self.open_threshold,
|
||||||
|
"closed_threshold": self.closed_threshold,
|
||||||
}
|
}
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -551,6 +551,12 @@ def step_env_and_process_transition(
|
|||||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||||
|
|
||||||
|
if hasattr(env, "get_raw_joint_positions"):
|
||||||
|
raw_joint_positions = env.get_raw_joint_positions()
|
||||||
|
if raw_joint_positions is not None:
|
||||||
|
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||||
|
|
||||||
new_info = info.copy()
|
new_info = info.copy()
|
||||||
new_info.update(processed_action_transition[TransitionKey.INFO])
|
new_info.update(processed_action_transition[TransitionKey.INFO])
|
||||||
|
|
||||||
@@ -568,6 +574,24 @@ def step_env_and_process_transition(
|
|||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
|
||||||
|
def reset_and_build_transition(
|
||||||
|
env: gym.Env,
|
||||||
|
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||||
|
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||||
|
) -> EnvTransition:
|
||||||
|
"""Reset env + processors and return the first env-processed transition."""
|
||||||
|
obs, info = env.reset()
|
||||||
|
env_processor.reset()
|
||||||
|
action_processor.reset()
|
||||||
|
complementary_data: dict[str, Any] = {}
|
||||||
|
if hasattr(env, "get_raw_joint_positions"):
|
||||||
|
raw_joint_positions = env.get_raw_joint_positions()
|
||||||
|
if raw_joint_positions is not None:
|
||||||
|
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||||
|
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||||
|
return env_processor(data=transition)
|
||||||
|
|
||||||
|
|
||||||
def control_loop(
|
def control_loop(
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||||
@@ -593,17 +617,7 @@ def control_loop(
|
|||||||
print("- When not intervening, robot will stay still")
|
print("- When not intervening, robot will stay still")
|
||||||
print("- Press Ctrl+C to exit")
|
print("- Press Ctrl+C to exit")
|
||||||
|
|
||||||
# Reset environment and processors
|
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||||
obs, info = env.reset()
|
|
||||||
complementary_data = (
|
|
||||||
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
|
|
||||||
)
|
|
||||||
env_processor.reset()
|
|
||||||
action_processor.reset()
|
|
||||||
|
|
||||||
# Process initial observation
|
|
||||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
|
||||||
transition = env_processor(data=transition)
|
|
||||||
|
|
||||||
# Determine if gripper is used
|
# Determine if gripper is used
|
||||||
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
||||||
@@ -726,12 +740,7 @@ def control_loop(
|
|||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
# Reset for new episode
|
# Reset for new episode
|
||||||
obs, info = env.reset()
|
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||||
env_processor.reset()
|
|
||||||
action_processor.reset()
|
|
||||||
|
|
||||||
transition = create_transition(observation=obs, info=info)
|
|
||||||
transition = env_processor(transition)
|
|
||||||
|
|
||||||
# Maintain fps timing
|
# Maintain fps timing
|
||||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||||
|
|||||||
Reference in New Issue
Block a user