mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ef8bfffbd7 | |||
| f887ab3f6a | |||
| c2556439e5 | |||
| d2a046dfc5 | |||
| 613d581f6c | |||
| 58b6d844c4 | |||
| 30e1886b64 | |||
| 9c9064e5be | |||
| 494f469a2b | |||
| cd105f65cb | |||
| 9c2af818ff | |||
| 6495bb9706 |
@@ -321,6 +321,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
|||||||
This step normalizes the `transition` object by:
|
This step normalizes the `transition` object by:
|
||||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||||
|
3. Copying `discrete_penalty` from `info` to `complementary_data`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -330,6 +331,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
|||||||
if TELEOP_ACTION_KEY in info:
|
if TELEOP_ACTION_KEY in info:
|
||||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||||
|
|
||||||
|
if DISCRETE_PENALTY_KEY in info:
|
||||||
|
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
|
||||||
|
|
||||||
if "is_intervention" in info:
|
if "is_intervention" in info:
|
||||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||||
|
|
||||||
@@ -348,18 +352,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
|||||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||||
"""
|
"""
|
||||||
Applies a penalty for inefficient gripper usage.
|
Applies a small per-transition cost on the discrete gripper action.
|
||||||
|
|
||||||
This step penalizes actions that attempt to close an already closed gripper or
|
Fires only when the commanded action would actually transition the gripper
|
||||||
open an already open one, based on position thresholds.
|
from one extreme to the other (close-while-open or open-while-closed).
|
||||||
|
This discourages gripper oscillation while leaving "stay" and saturating-further
|
||||||
|
commands unpenalized.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
penalty: The negative reward value to apply.
|
penalty: The negative reward value to apply.
|
||||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||||
|
open_threshold: Normalized state below which the gripper is considered "open".
|
||||||
|
closed_threshold: Normalized state above which the gripper is considered "closed".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
penalty: float = -0.01
|
penalty: float = -0.02
|
||||||
max_gripper_pos: float = 30.0
|
max_gripper_pos: float = 30.0
|
||||||
|
open_threshold: float = 0.1
|
||||||
|
closed_threshold: float = 0.9
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""
|
"""
|
||||||
@@ -391,9 +401,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
|||||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||||
|
|
||||||
# Calculate penalty boolean as in original
|
# Calculate penalty boolean as in original
|
||||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
# - currently open AND target is closed -> close transition
|
||||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
# - currently closed AND target is open -> open transition
|
||||||
)
|
is_open = gripper_state_normalized < self.open_threshold
|
||||||
|
is_closed = gripper_state_normalized > self.closed_threshold
|
||||||
|
cmd_close = gripper_action_normalized > self.closed_threshold
|
||||||
|
cmd_open = gripper_action_normalized < self.open_threshold
|
||||||
|
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
|
||||||
|
|
||||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||||
|
|
||||||
@@ -409,11 +423,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
|||||||
Returns the configuration of the step for serialization.
|
Returns the configuration of the step for serialization.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the penalty value and max gripper position.
|
A dictionary containing the penalty value, max gripper position,
|
||||||
|
and the open/closed thresholds.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"penalty": self.penalty,
|
"penalty": self.penalty,
|
||||||
"max_gripper_pos": self.max_gripper_pos,
|
"max_gripper_pos": self.max_gripper_pos,
|
||||||
|
"open_threshold": self.open_threshold,
|
||||||
|
"closed_threshold": self.closed_threshold,
|
||||||
}
|
}
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@@ -134,6 +134,15 @@ class _NormalizationMixin:
|
|||||||
if self.dtype is None:
|
if self.dtype is None:
|
||||||
self.dtype = torch.float32
|
self.dtype = torch.float32
|
||||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||||
|
self._reshape_visual_stats()
|
||||||
|
|
||||||
|
def _reshape_visual_stats(self) -> None:
    """Expand 1-D visual statistics from ``[C]`` to ``[C, 1, 1]``.

    For every entry of ``self.features`` whose type is ``FeatureType.VISUAL`` and
    that has an entry in ``self._tensor_stats``, each statistic that is a 1-D
    tensor is replaced by ``tensor.reshape(-1, 1, 1)`` so it can broadcast over
    image tensors. Statistics that are not tensors, or whose rank is not 1, are
    left as they are, so calling this method repeatedly is harmless.
    """
    for feature_key, feature_spec in self.features.items():
        # Only visual features need the trailing singleton axes.
        if feature_spec.type != FeatureType.VISUAL:
            continue
        if feature_key not in self._tensor_stats:
            continue

        stats_for_key = self._tensor_stats[feature_key]
        # Collect the replacements first, then write them back in one go; only
        # values of already-present keys are overwritten.
        expanded = {
            name: value.reshape(-1, 1, 1)
            for name, value in stats_for_key.items()
            if isinstance(value, Tensor) and value.ndim == 1
        }
        stats_for_key.update(expanded)
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||||
@@ -152,6 +161,7 @@ class _NormalizationMixin:
|
|||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||||
|
self._reshape_visual_stats()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
@@ -201,6 +211,7 @@ class _NormalizationMixin:
|
|||||||
# Don't load from state_dict, keep the explicitly provided stats
|
# Don't load from state_dict, keep the explicitly provided stats
|
||||||
# But ensure _tensor_stats is properly initialized
|
# But ensure _tensor_stats is properly initialized
|
||||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||||
|
self._reshape_visual_stats()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Normal behavior: load stats from state_dict
|
# Normal behavior: load stats from state_dict
|
||||||
@@ -211,6 +222,7 @@ class _NormalizationMixin:
|
|||||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||||
dtype=torch.float32, device=self.device
|
dtype=torch.float32, device=self.device
|
||||||
)
|
)
|
||||||
|
self._reshape_visual_stats()
|
||||||
|
|
||||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||||
# and other functions that rely on self.stats
|
# and other functions that rely on self.stats
|
||||||
|
|||||||
+29
-19
@@ -60,7 +60,7 @@ from torch.multiprocessing import Event, Queue
|
|||||||
from lerobot.cameras import opencv # noqa: F401
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||||
from lerobot.policies import make_policy
|
from lerobot.policies import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
@@ -89,9 +89,9 @@ from lerobot.utils.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .gym_manipulator import (
|
from .gym_manipulator import (
|
||||||
create_transition,
|
|
||||||
make_processors,
|
make_processors,
|
||||||
make_robot_env,
|
make_robot_env,
|
||||||
|
reset_and_build_transition,
|
||||||
step_env_and_process_transition,
|
step_env_and_process_transition,
|
||||||
)
|
)
|
||||||
from .process import ProcessSignalHandler
|
from .process import ProcessSignalHandler
|
||||||
@@ -261,13 +261,12 @@ def act_with_policy(
|
|||||||
policy = policy.eval()
|
policy = policy.eval()
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
obs, info = online_env.reset()
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
env_processor.reset()
|
policy_cfg=cfg.policy,
|
||||||
action_processor.reset()
|
dataset_stats=cfg.policy.dataset_stats,
|
||||||
|
)
|
||||||
|
|
||||||
# Process initial observation
|
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||||
transition = create_transition(observation=obs, info=info)
|
|
||||||
transition = env_processor(transition)
|
|
||||||
|
|
||||||
# NOTE: For the moment we will solely handle the case of a single environment
|
# NOTE: For the moment we will solely handle the case of a single environment
|
||||||
sum_reward_episode = 0
|
sum_reward_episode = 0
|
||||||
@@ -291,8 +290,21 @@ def act_with_policy(
|
|||||||
|
|
||||||
# Time policy inference and check if it meets FPS requirement
|
# Time policy inference and check if it meets FPS requirement
|
||||||
with policy_timer:
|
with policy_timer:
|
||||||
# Extract observation from transition for policy
|
normalized_observation = preprocessor.process_observation(observation)
|
||||||
action = policy.select_action(batch=observation)
|
action = policy.select_action(batch=normalized_observation)
|
||||||
|
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
|
||||||
|
# `select_action` concatenates an argmax index in env space at the last dim;
|
||||||
|
# action stats cover the continuous dims only, so feeding the full vector to
|
||||||
|
# the unnormalizer would shape-mismatch and would also corrupt the discrete
|
||||||
|
# index by treating it as a normalized value.
|
||||||
|
if cfg.policy.num_discrete_actions is not None:
|
||||||
|
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||||
|
discrete_action = action[..., -1:].to(
|
||||||
|
device=continuous_action.device, dtype=continuous_action.dtype
|
||||||
|
)
|
||||||
|
action = torch.cat([continuous_action, discrete_action], dim=-1)
|
||||||
|
else:
|
||||||
|
action = postprocessor.process_action(action)
|
||||||
policy_fps = policy_timer.fps_last
|
policy_fps = policy_timer.fps_last
|
||||||
|
|
||||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||||
@@ -326,7 +338,8 @@ def act_with_policy(
|
|||||||
|
|
||||||
# Check for intervention from transition info
|
# Check for intervention from transition info
|
||||||
intervention_info = new_transition[TransitionKey.INFO]
|
intervention_info = new_transition[TransitionKey.INFO]
|
||||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
|
||||||
|
if is_intervention:
|
||||||
episode_intervention = True
|
episode_intervention = True
|
||||||
episode_intervention_steps += 1
|
episode_intervention_steps += 1
|
||||||
|
|
||||||
@@ -334,6 +347,10 @@ def act_with_policy(
|
|||||||
"discrete_penalty": torch.tensor(
|
"discrete_penalty": torch.tensor(
|
||||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||||
),
|
),
|
||||||
|
# Forward the intervention flag so the learner can route this transition
|
||||||
|
# into the offline replay buffer (see `process_transitions` in learner.py).
|
||||||
|
# Use the plain string key so the payload survives torch.load(weights_only=True).
|
||||||
|
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
||||||
}
|
}
|
||||||
# Create transition for learner (convert to old format)
|
# Create transition for learner (convert to old format)
|
||||||
list_transition_to_send_to_learner.append(
|
list_transition_to_send_to_learner.append(
|
||||||
@@ -390,14 +407,7 @@ def act_with_policy(
|
|||||||
episode_intervention_steps = 0
|
episode_intervention_steps = 0
|
||||||
episode_total_steps = 0
|
episode_total_steps = 0
|
||||||
|
|
||||||
# Reset environment and processors
|
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||||
obs, info = online_env.reset()
|
|
||||||
env_processor.reset()
|
|
||||||
action_processor.reset()
|
|
||||||
|
|
||||||
# Process initial observation
|
|
||||||
transition = create_transition(observation=obs, info=info)
|
|
||||||
transition = env_processor(transition)
|
|
||||||
|
|
||||||
if cfg.env.fps is not None:
|
if cfg.env.fps is not None:
|
||||||
dt_time = time.perf_counter() - start_time
|
dt_time = time.perf_counter() - start_time
|
||||||
|
|||||||
@@ -383,10 +383,21 @@ def make_processors(
|
|||||||
GymHILAdapterProcessorStep(),
|
GymHILAdapterProcessorStep(),
|
||||||
Numpy2TorchActionProcessorStep(),
|
Numpy2TorchActionProcessorStep(),
|
||||||
VanillaObservationProcessorStep(),
|
VanillaObservationProcessorStep(),
|
||||||
AddBatchDimensionProcessorStep(),
|
|
||||||
DeviceProcessorStep(device=device),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Add time limit processor if reset config exists
|
||||||
|
if cfg.processor.reset is not None:
|
||||||
|
env_pipeline_steps.append(
|
||||||
|
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
|
||||||
|
)
|
||||||
|
|
||||||
|
env_pipeline_steps.extend(
|
||||||
|
[
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
DeviceProcessorStep(device=device),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return DataProcessorPipeline(
|
return DataProcessorPipeline(
|
||||||
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||||
), DataProcessorPipeline(
|
), DataProcessorPipeline(
|
||||||
@@ -551,8 +562,19 @@ def step_env_and_process_transition(
|
|||||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||||
|
|
||||||
|
if hasattr(env, "get_raw_joint_positions"):
|
||||||
|
raw_joint_positions = env.get_raw_joint_positions()
|
||||||
|
if raw_joint_positions is not None:
|
||||||
|
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||||
|
|
||||||
|
# Merge env and action-processor info: env wins for str keys, action-processor
|
||||||
|
# wins for `TeleopEvents` enum keys
|
||||||
|
action_info = processed_action_transition[TransitionKey.INFO]
|
||||||
new_info = info.copy()
|
new_info = info.copy()
|
||||||
new_info.update(processed_action_transition[TransitionKey.INFO])
|
for key, value in action_info.items():
|
||||||
|
if isinstance(key, TeleopEvents):
|
||||||
|
new_info[key] = value
|
||||||
|
|
||||||
new_transition = create_transition(
|
new_transition = create_transition(
|
||||||
observation=obs,
|
observation=obs,
|
||||||
@@ -568,6 +590,24 @@ def step_env_and_process_transition(
|
|||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
|
||||||
|
def reset_and_build_transition(
    env: gym.Env,
    env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
    action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
) -> EnvTransition:
    """Begin a new episode and return its first env-processed transition.

    The environment is reset first, then the env pipeline, then the action
    pipeline. If the environment exposes ``get_raw_joint_positions`` and it
    returns something other than ``None``, that value is attached to the
    transition's complementary data under ``"raw_joint_positions"``.

    Args:
        env: Environment to reset.
        env_processor: Pipeline that is reset and then applied to the initial transition.
        action_processor: Pipeline that is reset alongside ``env_processor``.

    Returns:
        The output of ``env_processor`` on the transition built from the reset
        observation and info.
    """
    initial_obs, initial_info = env.reset()

    # Reset order is kept: env pipeline first, action pipeline second.
    for pipeline in (env_processor, action_processor):
        pipeline.reset()

    extras: dict[str, Any] = {}
    # The joint-position accessor is optional; it is queried only after the reset above.
    if hasattr(env, "get_raw_joint_positions") and (joints := env.get_raw_joint_positions()) is not None:
        extras["raw_joint_positions"] = joints

    first_transition = create_transition(
        observation=initial_obs, info=initial_info, complementary_data=extras
    )
    return env_processor(data=first_transition)
|
||||||
|
|
||||||
|
|
||||||
def control_loop(
|
def control_loop(
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||||
@@ -593,17 +633,7 @@ def control_loop(
|
|||||||
print("- When not intervening, robot will stay still")
|
print("- When not intervening, robot will stay still")
|
||||||
print("- Press Ctrl+C to exit")
|
print("- Press Ctrl+C to exit")
|
||||||
|
|
||||||
# Reset environment and processors
|
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||||
obs, info = env.reset()
|
|
||||||
complementary_data = (
|
|
||||||
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
|
|
||||||
)
|
|
||||||
env_processor.reset()
|
|
||||||
action_processor.reset()
|
|
||||||
|
|
||||||
# Process initial observation
|
|
||||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
|
||||||
transition = env_processor(data=transition)
|
|
||||||
|
|
||||||
# Determine if gripper is used
|
# Determine if gripper is used
|
||||||
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
||||||
@@ -665,7 +695,7 @@ def control_loop(
|
|||||||
# Create a neutral action (no movement)
|
# Create a neutral action (no movement)
|
||||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||||
if use_gripper:
|
if use_gripper:
|
||||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||||
|
|
||||||
# Use the new step function
|
# Use the new step function
|
||||||
transition = step_env_and_process_transition(
|
transition = step_env_and_process_transition(
|
||||||
@@ -723,12 +753,7 @@ def control_loop(
|
|||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
# Reset for new episode
|
# Reset for new episode
|
||||||
obs, info = env.reset()
|
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||||
env_processor.reset()
|
|
||||||
action_processor.reset()
|
|
||||||
|
|
||||||
transition = create_transition(observation=obs, info=info)
|
|
||||||
transition = env_processor(transition)
|
|
||||||
|
|
||||||
# Maintain fps timing
|
# Maintain fps timing
|
||||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ from lerobot.common.wandb_utils import WandBLogger
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||||
from lerobot.policies import make_policy
|
from lerobot.policies import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||||
@@ -317,6 +317,11 @@ def add_actor_information_and_train(
|
|||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
|
preprocessor, _postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg.policy,
|
||||||
|
dataset_stats=cfg.policy.dataset_stats,
|
||||||
|
)
|
||||||
|
|
||||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||||
|
|
||||||
last_time_policy_pushed = time.time()
|
last_time_policy_pushed = time.time()
|
||||||
@@ -405,8 +410,8 @@ def add_actor_information_and_train(
|
|||||||
|
|
||||||
actions = batch[ACTION]
|
actions = batch[ACTION]
|
||||||
rewards = batch["reward"]
|
rewards = batch["reward"]
|
||||||
observations = batch["state"]
|
observations = preprocessor.process_observation(batch["state"])
|
||||||
next_observations = batch["next_state"]
|
next_observations = preprocessor.process_observation(batch["next_state"])
|
||||||
done = batch["done"]
|
done = batch["done"]
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||||
|
|
||||||
@@ -463,8 +468,8 @@ def add_actor_information_and_train(
|
|||||||
|
|
||||||
actions = batch[ACTION]
|
actions = batch[ACTION]
|
||||||
rewards = batch["reward"]
|
rewards = batch["reward"]
|
||||||
observations = batch["state"]
|
observations = preprocessor.process_observation(batch["state"])
|
||||||
next_observations = batch["next_state"]
|
next_observations = preprocessor.process_observation(batch["next_state"])
|
||||||
done = batch["done"]
|
done = batch["done"]
|
||||||
|
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||||
@@ -1163,7 +1168,7 @@ def process_transitions(
|
|||||||
|
|
||||||
# Add to offline buffer if it's an intervention
|
# Add to offline buffer if it's an intervention
|
||||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||||
TeleopEvents.IS_INTERVENTION
|
TeleopEvents.IS_INTERVENTION.value
|
||||||
):
|
):
|
||||||
offline_replay_buffer.add(**transition)
|
offline_replay_buffer.add(**transition)
|
||||||
|
|
||||||
|
|||||||
@@ -353,7 +353,8 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
|||||||
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
|
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
|
||||||
clip_min: The minimum allowed gripper joint position.
|
clip_min: The minimum allowed gripper joint position.
|
||||||
clip_max: The maximum allowed gripper joint position.
|
clip_max: The maximum allowed gripper joint position.
|
||||||
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
|
discrete_gripper: If True, interpret the input as a discrete class index
|
||||||
|
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
speed_factor: float = 20.0
|
speed_factor: float = 20.0
|
||||||
@@ -377,10 +378,10 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
|||||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||||
|
|
||||||
if self.discrete_gripper:
|
if self.discrete_gripper:
|
||||||
# Discrete gripper actions are in [0, 1, 2]
|
# Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
|
||||||
# 0: open, 1: close, 2: stay
|
# Negation accounts for SO100 sign (joint position increases on close).
|
||||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
# 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
|
||||||
gripper_vel = (gripper_vel - 1) * self.clip_max
|
gripper_vel = -(gripper_vel - 1) * self.clip_max
|
||||||
|
|
||||||
# Compute desired gripper position
|
# Compute desired gripper position
|
||||||
delta = gripper_vel * float(self.speed_factor)
|
delta = gripper_vel * float(self.speed_factor)
|
||||||
|
|||||||
Reference in New Issue
Block a user