Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-11 22:59:50 +00:00)
Compare commits
12 Commits
- ef8bfffbd7
- f887ab3f6a
- c2556439e5
- d2a046dfc5
- 613d581f6c
- 58b6d844c4
- 30e1886b64
- 9c9064e5be
- 494f469a2b
- cd105f65cb
- 9c2af818ff
- 6495bb9706
@@ -321,6 +321,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
     This step normalizes the `transition` object by:

     1. Copying `teleop_action` from `info` to `complementary_data`.
+    2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
     3. Copying `discrete_penalty` from `info` to `complementary_data`.
     """

     def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -330,6 +331,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
         if TELEOP_ACTION_KEY in info:
             complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]

         if DISCRETE_PENALTY_KEY in info:
             complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]

+        if "is_intervention" in info:
+            info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
+
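As a sanity check on the adapter's contract, here is a minimal self-contained sketch of the normalization it performs; the `TeleopEvents` enum and the key constants are stand-ins inferred from the hunks above, not the library's actual definitions.

```python
from enum import Enum

TELEOP_ACTION_KEY = "teleop_action"        # assumed string constants
DISCRETE_PENALTY_KEY = "discrete_penalty"

class TeleopEvents(Enum):                  # stand-in for the real enum
    IS_INTERVENTION = "is_intervention"

def adapt(info: dict, complementary_data: dict) -> None:
    """Mirror gym-hil's string-keyed info entries into the pipeline's keys."""
    if TELEOP_ACTION_KEY in info:
        complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
    if DISCRETE_PENALTY_KEY in info:
        complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
    if "is_intervention" in info:
        # Downstream steps look the flag up under the enum key.
        info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]

info = {"is_intervention": True, "discrete_penalty": -0.02}
comp: dict = {}
adapt(info, comp)
assert comp["discrete_penalty"] == -0.02
assert info[TeleopEvents.IS_INTERVENTION] is True
```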
@@ -348,18 +352,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
 @ProcessorStepRegistry.register("gripper_penalty_processor")
 class GripperPenaltyProcessorStep(ProcessorStep):
     """
-    Applies a penalty for inefficient gripper usage.
+    Applies a small per-transition cost on the discrete gripper action.

-    This step penalizes actions that attempt to close an already closed gripper or
-    open an already open one, based on position thresholds.
+    Fires only when the commanded action would actually transition the gripper
+    from one extreme to the other (close-while-open or open-while-closed).
+    This discourages gripper oscillation while leaving "stay" and saturating-further
+    commands unpenalized.

     Attributes:
         penalty: The negative reward value to apply.
         max_gripper_pos: The maximum position value for the gripper, used for normalization.
+        open_threshold: Normalized state below which the gripper is considered "open".
+        closed_threshold: Normalized state above which the gripper is considered "closed".
     """

-    penalty: float = -0.01
+    penalty: float = -0.02
     max_gripper_pos: float = 30.0
+    open_threshold: float = 0.1
+    closed_threshold: float = 0.9

     def __call__(self, transition: EnvTransition) -> EnvTransition:
         """
@@ -391,9 +401,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
         gripper_state_normalized = current_gripper_pos / self.max_gripper_pos

-        # Calculate penalty boolean as in original
-        gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
-            gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
-        )
+        # Fire only on transitions between extremes:
+        # - currently open AND target is closed -> close transition
+        # - currently closed AND target is open -> open transition
+        is_open = gripper_state_normalized < self.open_threshold
+        is_closed = gripper_state_normalized > self.closed_threshold
+        cmd_close = gripper_action_normalized > self.closed_threshold
+        cmd_open = gripper_action_normalized < self.open_threshold
+        gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)

         gripper_penalty = self.penalty * int(gripper_penalty_bool)
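A quick worked example of the new threshold logic (all values illustrative); note the mid-range case that the old 0.5/0.75 cutoffs would have penalized but the stricter extremes do not:

```python
# Defaults from the diff above.
penalty, max_gripper_pos = -0.02, 30.0
open_threshold, closed_threshold = 0.1, 0.9

def gripper_penalty(pos: float, action: float) -> float:
    state = pos / max_gripper_pos                 # normalize raw joint position
    is_open = state < open_threshold
    is_closed = state > closed_threshold
    cmd_close = action > closed_threshold
    cmd_open = action < open_threshold
    return penalty * int((is_open and cmd_close) or (is_closed and cmd_open))

print(gripper_penalty(pos=1.0, action=1.0))    # open -> close: -0.02
print(gripper_penalty(pos=29.0, action=0.0))   # closed -> open: -0.02
print(gripper_penalty(pos=1.0, action=0.0))    # open, commanded open: 0.0
print(gripper_penalty(pos=10.0, action=1.0))   # mid-range state (0.33): 0.0 now,
                                               # but penalized under the old < 0.5 cutoff
```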
@@ -409,11 +423,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
         Returns the configuration of the step for serialization.

         Returns:
-            A dictionary containing the penalty value and max gripper position.
+            A dictionary containing the penalty value, max gripper position,
+            and the open/closed thresholds.
         """
         return {
             "penalty": self.penalty,
             "max_gripper_pos": self.max_gripper_pos,
+            "open_threshold": self.open_threshold,
+            "closed_threshold": self.closed_threshold,
         }

     def reset(self) -> None:
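For illustration, a hedged sketch of the round-trip this config enables; the constructor signature is assumed to mirror `get_config()`'s keys, which the diff does not show:

```python
from dataclasses import dataclass

@dataclass
class GripperPenaltyProcessorStep:          # simplified stand-in
    penalty: float = -0.02
    max_gripper_pos: float = 30.0
    open_threshold: float = 0.1
    closed_threshold: float = 0.9

    def get_config(self) -> dict[str, float]:
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
            "open_threshold": self.open_threshold,
            "closed_threshold": self.closed_threshold,
        }

step = GripperPenaltyProcessorStep(penalty=-0.05)
clone = GripperPenaltyProcessorStep(**step.get_config())  # rebuild from config
assert clone == step  # the thresholds now survive serialization too
```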
@@ -134,6 +134,15 @@ class _NormalizationMixin:
         if self.dtype is None:
             self.dtype = torch.float32
         self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
+        self._reshape_visual_stats()
+
+    def _reshape_visual_stats(self) -> None:
+        """Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
+        for key, feature in self.features.items():
+            if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
+                for stat_name, stat_tensor in self._tensor_stats[key].items():
+                    if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
+                        self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)

     def to(
         self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
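Why the reshape matters: a `[C]` stat cannot broadcast against a `[C, H, W]` image, because broadcasting aligns trailing dimensions. A minimal sketch (stat values illustrative):

```python
import torch

img = torch.rand(3, 224, 224)                 # [C, H, W]
mean = torch.tensor([0.485, 0.456, 0.406])    # [C], per-channel mean

# (img - mean) with a [3] tensor fails: broadcasting would pair 3 against the
# trailing width dimension (224), not the channel dimension.
normalized = (img - mean.reshape(-1, 1, 1)) / 0.5  # [3, 1, 1] broadcasts per channel
assert normalized.shape == (3, 224, 224)
```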
@@ -152,6 +161,7 @@ class _NormalizationMixin:
         if dtype is not None:
             self.dtype = dtype
         self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
+        self._reshape_visual_stats()
         return self

     def state_dict(self) -> dict[str, Tensor]:
@@ -201,6 +211,7 @@ class _NormalizationMixin:
             # Don't load from state_dict, keep the explicitly provided stats
             # But ensure _tensor_stats is properly initialized
             self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)  # type: ignore[assignment]
+            self._reshape_visual_stats()
             return

         # Normal behavior: load stats from state_dict
@@ -211,6 +222,7 @@ class _NormalizationMixin:
             self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
                 dtype=torch.float32, device=self.device
             )
+        self._reshape_visual_stats()

         # Reconstruct the original stats dict from tensor stats for compatibility with to() method
         # and other functions that rely on self.stats
@@ -60,7 +60,7 @@ from torch.multiprocessing import Event, Queue
 from lerobot.cameras import opencv  # noqa: F401
 from lerobot.configs import parser
 from lerobot.configs.train import TrainRLServerPipelineConfig
-from lerobot.policies import make_policy
+from lerobot.policies import make_policy, make_pre_post_processors
 from lerobot.policies.sac.modeling_sac import SACPolicy
 from lerobot.robots import so_follower  # noqa: F401
 from lerobot.teleoperators import gamepad, so_leader  # noqa: F401
@@ -89,9 +89,9 @@ from lerobot.utils.utils import (
 )

 from .gym_manipulator import (
-    create_transition,
     make_processors,
     make_robot_env,
+    reset_and_build_transition,
     step_env_and_process_transition,
 )
 from .process import ProcessSignalHandler
@@ -261,13 +261,12 @@ def act_with_policy(
     policy = policy.eval()
     assert isinstance(policy, nn.Module)

-    obs, info = online_env.reset()
-    env_processor.reset()
-    action_processor.reset()
+    preprocessor, postprocessor = make_pre_post_processors(
+        policy_cfg=cfg.policy,
+        dataset_stats=cfg.policy.dataset_stats,
+    )

     # Process initial observation
-    transition = create_transition(observation=obs, info=info)
-    transition = env_processor(transition)
+    transition = reset_and_build_transition(online_env, env_processor, action_processor)

     # NOTE: For the moment we will solely handle the case of a single environment
     sum_reward_episode = 0
@@ -291,8 +290,21 @@ def act_with_policy(

         # Time policy inference and check if it meets FPS requirement
         with policy_timer:
-            # Extract observation from transition for policy
-            action = policy.select_action(batch=observation)
+            normalized_observation = preprocessor.process_observation(observation)
+            action = policy.select_action(batch=normalized_observation)
+            # Unnormalize only the continuous part. When `num_discrete_actions` is set,
+            # `select_action` concatenates an argmax index in env space at the last dim;
+            # action stats cover the continuous dims only, so feeding the full vector to
+            # the unnormalizer would shape-mismatch and would also corrupt the discrete
+            # index by treating it as a normalized value.
+            if cfg.policy.num_discrete_actions is not None:
+                continuous_action = postprocessor.process_action(action[..., :-1])
+                discrete_action = action[..., -1:].to(
+                    device=continuous_action.device, dtype=continuous_action.dtype
+                )
+                action = torch.cat([continuous_action, discrete_action], dim=-1)
+            else:
+                action = postprocessor.process_action(action)
         policy_fps = policy_timer.fps_last

         log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
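A small illustration of the split (dimensions and stats are made up): the trailing discrete index passes through untouched while only the continuous dims are unnormalized:

```python
import torch

# Illustrative: 3 continuous dims plus 1 trailing discrete argmax index.
action = torch.tensor([0.2, -0.5, 0.9, 2.0])

# Stats cover the 3 continuous dims only; unnormalizing all 4 would mismatch.
mean = torch.tensor([0.0, 0.0, 0.0])
std = torch.tensor([0.1, 0.1, 0.1])

continuous = action[..., :-1] * std + mean  # stand-in for process_action()
discrete = action[..., -1:]                 # index stays in env space
env_action = torch.cat([continuous, discrete], dim=-1)
print(env_action)  # tensor([ 0.0200, -0.0500,  0.0900,  2.0000])
```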
@@ -326,7 +338,8 @@ def act_with_policy(

             # Check for intervention from transition info
             intervention_info = new_transition[TransitionKey.INFO]
-            if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
+            is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
+            if is_intervention:
                 episode_intervention = True
                 episode_intervention_steps += 1
@@ -334,6 +347,10 @@ def act_with_policy(
                 "discrete_penalty": torch.tensor(
                     [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
                 ),
+                # Forward the intervention flag so the learner can route this transition
+                # into the offline replay buffer (see `process_transitions` in learner.py).
+                # Use the plain string key so the payload survives torch.load(weights_only=True).
+                TeleopEvents.IS_INTERVENTION.value: is_intervention,
             }
             # Create transition for learner (convert to old format)
             list_transition_to_send_to_learner.append(
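Why the plain string key: in recent PyTorch, `torch.load(weights_only=True)` restricts unpickling to builtin and allowlisted types, so a custom enum used as a dict key would be rejected. A minimal sketch, assuming `TeleopEvents` is an ordinary `Enum`:

```python
import io
from enum import Enum

import torch

class TeleopEvents(Enum):            # stand-in for the real enum
    IS_INTERVENTION = "is_intervention"

safe = {"is_intervention": True}                  # plain string key
risky = {TeleopEvents.IS_INTERVENTION: True}      # enum key

buf = io.BytesIO()
torch.save(safe, buf)
buf.seek(0)
torch.load(buf, weights_only=True)                # fine: builtin types only

buf = io.BytesIO()
torch.save(risky, buf)
buf.seek(0)
try:
    torch.load(buf, weights_only=True)            # rejected: custom Enum class
except Exception as e:
    print(type(e).__name__)
```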
@@ -390,14 +407,7 @@ def act_with_policy(
                 episode_intervention_steps = 0
                 episode_total_steps = 0

-                # Reset environment and processors
-                obs, info = online_env.reset()
-                env_processor.reset()
-                action_processor.reset()
-
-                # Process initial observation
-                transition = create_transition(observation=obs, info=info)
-                transition = env_processor(transition)
+                transition = reset_and_build_transition(online_env, env_processor, action_processor)

             if cfg.env.fps is not None:
                 dt_time = time.perf_counter() - start_time
@@ -383,10 +383,21 @@ def make_processors(
         GymHILAdapterProcessorStep(),
         Numpy2TorchActionProcessorStep(),
         VanillaObservationProcessorStep(),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=device),
     ]

+    # Add time limit processor if reset config exists
+    if cfg.processor.reset is not None:
+        env_pipeline_steps.append(
+            TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
+        )
+
+    env_pipeline_steps.extend(
+        [
+            AddBatchDimensionProcessorStep(),
+            DeviceProcessorStep(device=device),
+        ]
+    )
+
     return DataProcessorPipeline(
         steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
     ), DataProcessorPipeline(
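For concreteness, the time-limit arithmetic under assumed config values:

```python
# Illustrative: with control_time_s = 20.0 and fps = 30, the processor
# truncates an episode after int(20.0 * 30) = 600 env steps.
control_time_s = 20.0   # assumed cfg.processor.reset.control_time_s
fps = 30                # assumed cfg.fps
max_episode_steps = int(control_time_s * fps)
assert max_episode_steps == 600
```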
@@ -551,8 +562,19 @@ def step_env_and_process_transition(
     terminated = terminated or processed_action_transition[TransitionKey.DONE]
     truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
+    complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()

+    if hasattr(env, "get_raw_joint_positions"):
+        raw_joint_positions = env.get_raw_joint_positions()
+        if raw_joint_positions is not None:
+            complementary_data["raw_joint_positions"] = raw_joint_positions
+
+    # Merge env and action-processor info: env wins for str keys, action-processor
+    # wins for `TeleopEvents` enum keys
+    action_info = processed_action_transition[TransitionKey.INFO]
     new_info = info.copy()
-    new_info.update(processed_action_transition[TransitionKey.INFO])
+    for key, value in action_info.items():
+        if isinstance(key, TeleopEvents):
+            new_info[key] = value

     new_transition = create_transition(
         observation=obs,
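A minimal sketch of the merge policy described in the comment, with made-up `info` payloads:

```python
from enum import Enum

class TeleopEvents(Enum):            # stand-in enum
    IS_INTERVENTION = "is_intervention"

env_info = {"success": True, "is_intervention": False}
action_info = {"is_intervention": True, TeleopEvents.IS_INTERVENTION: True}

# Env wins for plain string keys; the action processor only contributes
# its TeleopEvents enum keys on top.
merged = env_info.copy()
for key, value in action_info.items():
    if isinstance(key, TeleopEvents):
        merged[key] = value

assert merged["is_intervention"] is False             # env's string key kept
assert merged[TeleopEvents.IS_INTERVENTION] is True   # enum key from processor
```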
@@ -568,6 +590,24 @@ def step_env_and_process_transition(
     return new_transition


+def reset_and_build_transition(
+    env: gym.Env,
+    env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
+    action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
+) -> EnvTransition:
+    """Reset env + processors and return the first env-processed transition."""
+    obs, info = env.reset()
+    env_processor.reset()
+    action_processor.reset()
+    complementary_data: dict[str, Any] = {}
+    if hasattr(env, "get_raw_joint_positions"):
+        raw_joint_positions = env.get_raw_joint_positions()
+        if raw_joint_positions is not None:
+            complementary_data["raw_joint_positions"] = raw_joint_positions
+    transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
+    return env_processor(data=transition)
+
+
 def control_loop(
     env: gym.Env,
     env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
@@ -593,17 +633,7 @@ def control_loop(
     print("- When not intervening, robot will stay still")
     print("- Press Ctrl+C to exit")

-    # Reset environment and processors
-    obs, info = env.reset()
-    complementary_data = (
-        {"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
-    )
-    env_processor.reset()
-    action_processor.reset()
-
-    # Process initial observation
-    transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
-    transition = env_processor(data=transition)
+    transition = reset_and_build_transition(env, env_processor, action_processor)

     # Determine if gripper is used
     use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
@@ -665,7 +695,7 @@ def control_loop(
             # Create a neutral action (no movement)
             neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
             if use_gripper:
-                neutral_action = torch.cat([neutral_action, torch.tensor([0.0])])  # Gripper stay
+                neutral_action = torch.cat([neutral_action, torch.tensor([1.0])])  # Gripper stay

             # Use the new step function
             transition = step_env_and_process_transition(
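Why `1.0` is the fix: under the `{0 = close, 1 = stay, 2 = open}` convention, the old `0.0` commanded a close on every idle step. A small check against the downstream discrete-to-velocity mapping (`clip_max` illustrative):

```python
import torch

neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])])  # class 1 = stay

# Sanity check: class 1 maps to zero gripper velocity in GripperVelocityToJoint.
clip_max = 50.0  # illustrative value
gripper_vel = -(neutral_action[-1] - 1) * clip_max
assert gripper_vel == 0.0  # gripper holds its position
```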
@@ -723,12 +753,7 @@ def control_loop(
                 dataset.save_episode()

             # Reset for new episode
-            obs, info = env.reset()
-            env_processor.reset()
-            action_processor.reset()
-
-            transition = create_transition(observation=obs, info=info)
-            transition = env_processor(transition)
+            transition = reset_and_build_transition(env, env_processor, action_processor)

         # Maintain fps timing
         precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
@@ -70,7 +70,7 @@ from lerobot.common.wandb_utils import WandBLogger
 from lerobot.configs import parser
 from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.datasets import LeRobotDataset, make_dataset
-from lerobot.policies import make_policy
+from lerobot.policies import make_policy, make_pre_post_processors
 from lerobot.policies.sac.modeling_sac import SACPolicy
 from lerobot.robots import so_follower  # noqa: F401
 from lerobot.teleoperators import gamepad, so_leader  # noqa: F401
@@ -317,6 +317,11 @@ def add_actor_information_and_train(

     policy.train()

+    preprocessor, _postprocessor = make_pre_post_processors(
+        policy_cfg=cfg.policy,
+        dataset_stats=cfg.policy.dataset_stats,
+    )
+
     push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)

     last_time_policy_pushed = time.time()
@@ -405,8 +410,8 @@ def add_actor_information_and_train(

         actions = batch[ACTION]
         rewards = batch["reward"]
-        observations = batch["state"]
-        next_observations = batch["next_state"]
+        observations = preprocessor.process_observation(batch["state"])
+        next_observations = preprocessor.process_observation(batch["next_state"])
         done = batch["done"]

         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
@@ -463,8 +468,8 @@ def add_actor_information_and_train(

         actions = batch[ACTION]
         rewards = batch["reward"]
-        observations = batch["state"]
-        next_observations = batch["next_state"]
+        observations = preprocessor.process_observation(batch["state"])
+        next_observations = preprocessor.process_observation(batch["next_state"])
         done = batch["done"]

         check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
@@ -1163,7 +1168,7 @@ def process_transitions(

         # Add to offline buffer if it's an intervention
         if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
-            TeleopEvents.IS_INTERVENTION
+            TeleopEvents.IS_INTERVENTION.value
         ):
             offline_replay_buffer.add(**transition)
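Why `.value` is needed here: the actor serializes the flag under the plain string key, so after deserialization `complementary_info` holds string keys, and a lookup with the bare enum member misses (assuming `TeleopEvents` is a plain `Enum`, not a str-mixin):

```python
from enum import Enum

class TeleopEvents(Enum):            # assumed plain Enum
    IS_INTERVENTION = "is_intervention"

# What the learner sees after the payload round-trips through serialization:
complementary_info = {"is_intervention": True}

print(complementary_info.get(TeleopEvents.IS_INTERVENTION))        # None: enum key misses
print(complementary_info.get(TeleopEvents.IS_INTERVENTION.value))  # True
```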
@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
         speed_factor: A scaling factor to convert the normalized velocity command to a position change.
         clip_min: The minimum allowed gripper joint position.
         clip_max: The maximum allowed gripper joint position.
-        discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
+        discrete_gripper: If True, interpret the input as a discrete class index
+            {0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
     """

     speed_factor: float = 20.0
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
             raise ValueError("Joints observation is require for computing robot kinematics")

         if self.discrete_gripper:
-            # Discrete gripper actions are in [0, 1, 2]
-            # 0: open, 1: close, 2: stay
-            # We need to shift them to [-1, 0, 1] and then scale them to clip_max
-            gripper_vel = (gripper_vel - 1) * self.clip_max
+            # Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
+            # Negation accounts for SO100 sign (joint position increases on close).
+            # 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
+            gripper_vel = -(gripper_vel - 1) * self.clip_max

         # Compute desired gripper position
         delta = gripper_vel * float(self.speed_factor)
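The corrected mapping, worked through (`clip_max` is an assumed value):

```python
clip_max = 50.0  # illustrative; the real default is not shown in this diff

def discrete_to_vel(cmd: int) -> float:
    # Shift {0, 1, 2} to {-1, 0, 1}, scale by clip_max, and negate for SO100's
    # sign convention (joint position increases as the gripper closes).
    return -(cmd - 1) * clip_max

print(discrete_to_vel(0))  #  50.0 -> positive velocity closes the gripper
print(discrete_to_vel(1))  #   0.0 -> stay
print(discrete_to_vel(2))  # -50.0 -> negative velocity opens the gripper
```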