mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873)
* fix(RL): add missing config arguments
* respond to Copilot review
* fix(revert penalty in reward): revert the gripper penalty addition in the reward; this is already done in compute_loss_discrete_critic.

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
@@ -205,6 +205,7 @@ class ObservationConfig:
|
|||||||
|
|
||||||
add_joint_velocity_to_observation: bool = False
|
add_joint_velocity_to_observation: bool = False
|
||||||
add_current_to_observation: bool = False
|
add_current_to_observation: bool = False
|
||||||
|
add_ee_pose_to_observation: bool = False
|
||||||
display_cameras: bool = False
|
display_cameras: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||||
"""
|
"""
|
||||||
Applies a penalty for inefficient gripper usage.
|
Applies a penalty for inefficient gripper usage.
|
||||||
|
|
||||||
@@ -329,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
|||||||
penalty: float = -0.01
|
penalty: float = -0.01
|
||||||
max_gripper_pos: float = 30.0
|
max_gripper_pos: float = 30.0
|
||||||
|
|
||||||
def complementary_data(self, complementary_data: dict) -> dict:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""
|
"""
|
||||||
Calculates the gripper penalty and adds it to the complementary data.
|
Calculates the gripper penalty and adds it to the complementary data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
complementary_data: The incoming complementary data, which should contain
|
transition: The incoming environment transition.
|
||||||
raw joint positions.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
The modified transition with the penalty added to complementary data.
|
||||||
"""
|
"""
|
||||||
action = self.transition.get(TransitionKey.ACTION)
|
new_transition = transition.copy()
|
||||||
|
action = new_transition.get(TransitionKey.ACTION)
|
||||||
|
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
|
||||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||||
if raw_joint_positions is None:
|
if raw_joint_positions is None:
|
||||||
return complementary_data
|
return new_transition
|
||||||
|
|
||||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||||
if current_gripper_pos is None:
|
if current_gripper_pos is None:
|
||||||
return complementary_data
|
return new_transition
|
||||||
|
|
||||||
# Gripper action is a PolicyAction at this stage
|
# Gripper action is a PolicyAction at this stage
|
||||||
gripper_action = action[-1].item()
|
gripper_action = action[-1].item()
|
||||||
@@ -364,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
|||||||
|
|
||||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||||
|
|
||||||
# Create new complementary data with penalty info
|
# Update complementary data with penalty info
|
||||||
new_complementary_data = dict(complementary_data)
|
new_complementary_data = dict(complementary_data)
|
||||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||||
|
|
||||||
return new_complementary_data
|
return new_transition
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
def get_config(self) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -412,7 +412,10 @@ def make_processors(
|
|||||||
if cfg.processor.observation.add_current_to_observation:
|
if cfg.processor.observation.add_current_to_observation:
|
||||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||||
|
|
||||||
if kinematics_solver is not None:
|
add_ee_pose = (
|
||||||
|
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||||
|
)
|
||||||
|
if kinematics_solver is not None and add_ee_pose:
|
||||||
env_pipeline_steps.append(
|
env_pipeline_steps.append(
|
||||||
ForwardKinematicsJointsToEEObservation(
|
ForwardKinematicsJointsToEEObservation(
|
||||||
kinematics=kinematics_solver,
|
kinematics=kinematics_solver,
|
||||||
@@ -435,7 +438,12 @@ def make_processors(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add gripper penalty processor if gripper config exists and enabled
|
# Add gripper penalty processor if gripper config exists and enabled
|
||||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||||
|
if (
|
||||||
|
cfg.processor.gripper is not None
|
||||||
|
and cfg.processor.gripper.use_gripper
|
||||||
|
and cfg.processor.max_gripper_pos is not None
|
||||||
|
):
|
||||||
env_pipeline_steps.append(
|
env_pipeline_steps.append(
|
||||||
GripperPenaltyProcessorStep(
|
GripperPenaltyProcessorStep(
|
||||||
penalty=cfg.processor.gripper.gripper_penalty,
|
penalty=cfg.processor.gripper.gripper_penalty,
|
||||||
|
|||||||
Reference in New Issue
Block a user