From b292dbbc557067c9701e844d3897600ae06beaca Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 4 Aug 2025 16:41:42 +0200 Subject: [PATCH] change folder structure to reduce the size of gym_manip --- src/lerobot/processor/__init__.py | 18 + src/lerobot/processor/hil_processor.py | 243 +++++++++++ src/lerobot/processor/robot_processor.py | 245 +++++++++++ src/lerobot/scripts/rl/gym_manipulator.py | 489 +--------------------- 4 files changed, 518 insertions(+), 477 deletions(-) create mode 100644 src/lerobot/processor/hil_processor.py create mode 100644 src/lerobot/processor/robot_processor.py diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 0a5a5dd2c..4effc220d 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -15,6 +15,12 @@ # limitations under the License. from .device_processor import DeviceProcessor +from .hil_processor import ( + GripperPenaltyProcessor, + ImageCropResizeProcessor, + InterventionActionProcessor, + TimeLimitProcessor, +) from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .observation_processor import ( ImageProcessor, @@ -36,15 +42,26 @@ from .pipeline import ( TruncatedProcessor, ) from .rename_processor import RenameProcessor +from .robot_processor import ( + InverseKinematicsProcessor, + JointVelocityProcessor, + MotorCurrentProcessor, +) __all__ = [ "ActionProcessor", "DeviceProcessor", "DoneProcessor", "EnvTransition", + "GripperPenaltyProcessor", "IdentityProcessor", + "ImageCropResizeProcessor", "ImageProcessor", "InfoProcessor", + "InterventionActionProcessor", + "InverseKinematicsProcessor", + "JointVelocityProcessor", + "MotorCurrentProcessor", "NormalizerProcessor", "UnnormalizerProcessor", "ObservationProcessor", @@ -54,6 +71,7 @@ __all__ = [ "RewardProcessor", "RobotProcessor", "StateProcessor", + "TimeLimitProcessor", "TransitionKey", "TruncatedProcessor", "VanillaObservationProcessor", diff --git 
a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py new file mode 100644 index 000000000..7790bc0bb --- /dev/null +++ b/src/lerobot/processor/hil_processor.py @@ -0,0 +1,243 @@ +from dataclasses import dataclass +from typing import Any + +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + + +@dataclass +@ProcessorStepRegistry.register("image_crop_resize_processor") +class ImageCropResizeProcessor: + """Crop and resize image observations.""" + + crop_params_dict: dict[str, tuple[int, int, int, int]] + resize_size: tuple[int, int] = (128, 128) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + + if self.resize_size is None and not self.crop_params_dict: + return transition + + new_observation = dict(observation) + + # Process all image keys in the observation + for key in observation: + if "image" not in key: + continue + + image = observation[key] + device = image.device + if device.type == "mps": + image = image.cpu() + # Crop if crop params are provided for this key + if key in self.crop_params_dict: + crop_params = self.crop_params_dict[key] + image = F.crop(image, *crop_params) + # Always resize + image = F.resize(image, self.resize_size) + image = image.clamp(0.0, 1.0) + new_observation[key] = image.to(device) + + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = new_observation + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "crop_params_dict": self.crop_params_dict, + "resize_size": self.resize_size, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + 
+ def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register("time_limit_processor") +class TimeLimitProcessor: + """Track episode time and enforce time limits.""" + + max_episode_steps: int + current_step: int = 0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is None: + return transition + + self.current_step += 1 + if self.current_step >= self.max_episode_steps: + truncated = True + new_transition = transition.copy() + new_transition[TransitionKey.TRUNCATED] = truncated + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "max_episode_steps": self.max_episode_steps, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + self.current_step = 0 + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register("gripper_penalty_processor") +class GripperPenaltyProcessor: + penalty: float = -0.01 + max_gripper_pos: float = 30.0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Calculate gripper penalty and add to complementary data.""" + action = transition.get(TransitionKey.ACTION) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + + if complementary_data is None or action is None: + return transition + + current_gripper_pos = complementary_data.get("raw_joint_positions", None)[-1] + if current_gripper_pos is None: + return transition + + gripper_action = action[-1].item() + gripper_action_normalized = gripper_action / self.max_gripper_pos + + # Normalize gripper state and action + gripper_state_normalized = current_gripper_pos / self.max_gripper_pos + gripper_action_normalized = 
gripper_action - 1.0 + + # Calculate penalty boolean as in original + gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 + ) + + gripper_penalty = self.penalty * int(gripper_penalty_bool) + + # Add penalty information to complementary data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + # Create new complementary data with penalty info + new_complementary_data = dict(complementary_data) + new_complementary_data["discrete_penalty"] = gripper_penalty + + # Create new transition with updated complementary data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "penalty": self.penalty, + "max_gripper_pos": self.max_gripper_pos, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + """Reset the processor state.""" + self.last_gripper_state = None + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register("intervention_action_processor") +class InterventionActionProcessor: + """Handle action intervention based on signals in the transition. + + This processor checks for intervention signals in the transition's complementary data + and overrides agent actions when intervention is active. 
+ """ + + use_gripper: bool = False + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + + # Get intervention signals from complementary data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + teleop_action = complementary_data.get("teleop_action", {}) + is_intervention = complementary_data.get("is_intervention", False) + terminate_episode = complementary_data.get("terminate_episode", False) + success = complementary_data.get("success", False) + rerecord_episode = complementary_data.get("rerecord_episode", False) + + new_transition = transition.copy() + + # Override action if intervention is active + if is_intervention and teleop_action: + # Convert teleop_action dict to tensor format + action_list = [ + teleop_action.get("delta_x", 0.0), + teleop_action.get("delta_y", 0.0), + teleop_action.get("delta_z", 0.0), + ] + if self.use_gripper: + action_list.append(teleop_action.get("gripper", 1.0)) + + teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device) + new_transition[TransitionKey.ACTION] = teleop_action_tensor + + # Handle episode termination + new_transition[TransitionKey.DONE] = bool(terminate_episode) + new_transition[TransitionKey.REWARD] = float(success) + + # Update info with intervention metadata + info = new_transition.get(TransitionKey.INFO, {}) + info["is_intervention"] = is_intervention + info["rerecord_episode"] = rerecord_episode + info["next.success"] = success if terminate_episode else info.get("next.success", False) + new_transition[TransitionKey.INFO] = info + new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[ + TransitionKey.ACTION + ] + + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "use_gripper": self.use_gripper, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, 
state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/robot_processor.py b/src/lerobot/processor/robot_processor.py new file mode 100644 index 000000000..f145c3b0a --- /dev/null +++ b/src/lerobot/processor/robot_processor.py @@ -0,0 +1,245 @@ +from dataclasses import dataclass, field +from typing import Any + +import gymnasium as gym +import numpy as np +import torch + +from lerobot.configs.types import PolicyFeature +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + + +@dataclass +@ProcessorStepRegistry.register("joint_velocity_processor") +class JointVelocityProcessor: + """Add joint velocity information to observations. + + Computes joint velocities by tracking changes in joint positions over time. + """ + + joint_velocity_limits: float = 100.0 + dt: float = 1.0 / 10 + + last_joint_positions: torch.Tensor | None = None + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + + # Get current joint positions (assuming they're in observation.state) + current_positions = observation.get("observation.state") + if current_positions is None: + return transition + + # Initialize last joint positions if not already set + if self.last_joint_positions is None: + self.last_joint_positions = current_positions.clone() + + # Compute velocities + joint_velocities = (current_positions - self.last_joint_positions) / self.dt + self.last_joint_positions = current_positions.clone() + + # Extend observation with velocities + extended_state = torch.cat([current_positions, joint_velocities], dim=-1) + + # Create new observation dict + new_observation = dict(observation) + 
new_observation["observation.state"] = extended_state + + # Return new transition + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = new_observation + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "joint_velocity_limits": self.joint_velocity_limits, + "dt": self.dt, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + self.last_joint_positions = None + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register("current_processor") +class MotorCurrentProcessor: + """Add motor current information to observations.""" + + env: gym.Env = None + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + + # Read present motor currents directly from the robot bus via sync_read + present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current") + motor_currents = torch.tensor( + [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors], + dtype=torch.float32, + ).unsqueeze(0) + + current_state = observation.get("observation.state") + if current_state is None: + return transition + + extended_state = torch.cat([current_state, motor_currents], dim=-1) + + # Create new observation dict + new_observation = dict(observation) + new_observation["observation.state"] = extended_state + + # Return new transition + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = new_observation + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def 
reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register("inverse_kinematics_processor") +class InverseKinematicsProcessor: + """Convert end-effector space actions to joint space using inverse kinematics. + + This processor transforms delta commands in end-effector space (delta_x, delta_y, delta_z) + to joint space commands using forward and inverse kinematics. It maintains the current + end-effector pose and joint positions to compute the transformations. + """ + + urdf_path: str + target_frame_name: str = "gripper_link" + end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0}) + end_effector_bounds: dict[str, list[float]] | None = None + max_gripper_pos: float = 30.0 + + # State tracking + current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False) + current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False) + kinematics: RobotKinematics | None = field(default=None, init=False, repr=False) + + def __post_init__(self): + """Initialize the kinematics module after dataclass initialization.""" + if self.urdf_path: + self.kinematics = RobotKinematics( + urdf_path=self.urdf_path, + target_frame_name=self.target_frame_name, + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + + action_np = action.detach().cpu().numpy().squeeze() + + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + raw_joint_positions = complementary_data.get("raw_joint_positions") + current_gripper_pos = raw_joint_positions[-1] + if self.current_joint_pos is None: + self.current_joint_pos = raw_joint_positions + + # NOTE(review): guard below is unreachable — current_joint_pos was assigned just above + if self.current_joint_pos is None: + return transition # Cannot proceed without 
joint positions + + # Calculate current end-effector position using forward kinematics + if self.current_ee_pos is None: + self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos) + + # Scale deltas by step sizes + delta_ee = np.array( + [ + action_np[0] * self.end_effector_step_sizes["x"], + action_np[1] * self.end_effector_step_sizes["y"], + action_np[2] * self.end_effector_step_sizes["z"], + ], + dtype=np.float32, + ) + + # Set desired end-effector position by adding delta + desired_ee_pos = np.eye(4) + desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation + + # Add delta to position and clip to bounds + desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee + if self.end_effector_bounds is not None: + desired_ee_pos[:3, 3] = np.clip( + desired_ee_pos[:3, 3], + self.end_effector_bounds["min"], + self.end_effector_bounds["max"], + ) + + # Compute inverse kinematics to get joint positions + target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos) + + # Update current state + self.current_ee_pos = desired_ee_pos.copy() + self.current_joint_pos = target_joint_values.copy() + + # Create new action with joint space commands + gripper_action = current_gripper_pos + if len(action_np) > 3: + # Handle gripper command separately + gripper_command = action_np[3] + + # Process gripper command (convert from [0,2] to delta) and discretize + gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos + gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos) + + # Combine joint positions and gripper + target_joint_values[-1] = gripper_action + + converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype) + + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = converted_action + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "urdf_path": self.urdf_path, 
+ "target_frame_name": self.target_frame_name, + "end_effector_step_sizes": self.end_effector_step_sizes, + "end_effector_bounds": self.end_effector_bounds, + "max_gripper_pos": self.max_gripper_pos, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + """Reset the processor state.""" + self.current_ee_pos = None + self.current_joint_pos = None + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index cb3a97701..9673bb03e 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -16,27 +16,34 @@ import logging import time -from dataclasses import dataclass, field from typing import Any import gymnasium as gym import numpy as np import torch -import torchvision.transforms.functional as F # noqa: N812 from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.configs.types import PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.envs.configs import EnvConfig -from lerobot.model.kinematics import RobotKinematics from lerobot.processor import ( DeviceProcessor, ImageProcessor, RobotProcessor, StateProcessor, ) -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey +from lerobot.processor.hil_processor import ( + GripperPenaltyProcessor, + ImageCropResizeProcessor, + InterventionActionProcessor, + TimeLimitProcessor, +) +from lerobot.processor.pipeline import TransitionKey +from lerobot.processor.robot_processor import ( + InverseKinematicsProcessor, + JointVelocityProcessor, + MotorCurrentProcessor, +) from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, @@ -279,477 +286,6 @@ class RobotEnv(gym.Env): 
self.robot.disconnect() -@dataclass -@ProcessorStepRegistry.register("joint_velocity_processor") -class JointVelocityProcessor: - """Add joint velocity information to observations. - - Computes joint velocities by tracking changes in joint positions over time. - """ - - joint_velocity_limits: float = 100.0 - dt: float = 1.0 / 10 - - last_joint_positions: torch.Tensor | None = None - - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return transition - - # Get current joint positions (assuming they're in observation.state) - current_positions = observation.get("observation.state") - if current_positions is None: - return transition - - # Initialize last joint positions if not already set - if self.last_joint_positions is None: - self.last_joint_positions = current_positions.clone() - - # Compute velocities - joint_velocities = (current_positions - self.last_joint_positions) / self.dt - self.last_joint_positions = current_positions.clone() - - # Extend observation with velocities - extended_state = torch.cat([current_positions, joint_velocities], dim=-1) - - # Create new observation dict - new_observation = dict(observation) - new_observation["observation.state"] = extended_state - - # Return new transition - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = new_observation - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "joint_velocity_limits": self.joint_velocity_limits, - "fps": self.fps, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - self.last_joint_positions = None - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -@dataclass -@ProcessorStepRegistry.register("current_processor") -class 
MotorCurrentProcessor: - """Add motor current information to observations.""" - - env: gym.Env = None - - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return transition - - # Get current values from complementary_data (where robot state would be stored) - present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current") - motor_currents = torch.tensor( - [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors], - dtype=torch.float32, - ).unsqueeze(0) - - current_state = observation.get("observation.state") - if current_state is None: - return transition - - extended_state = torch.cat([current_state, motor_currents], dim=-1) - - # Create new observation dict - new_observation = dict(observation) - new_observation["observation.state"] = extended_state - - # Return new transition - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = new_observation - return new_transition - - def get_config(self) -> dict[str, Any]: - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -@dataclass -@ProcessorStepRegistry.register("image_crop_resize_processor") -class ImageCropResizeProcessor: - """Crop and resize image observations.""" - - crop_params_dict: dict[str, tuple[int, int, int, int]] - resize_size: tuple[int, int] = (128, 128) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return transition - - if self.resize_size is None and not self.crop_params_dict: - return transition - - new_observation = dict(observation) - - # Process all image keys in the 
observation - for key in observation: - if "image" not in key: - continue - - image = observation[key] - device = image.device - if device.type == "mps": - image = image.cpu() - # Crop if crop params are provided for this key - if key in self.crop_params_dict: - crop_params = self.crop_params_dict[key] - image = F.crop(image, *crop_params) - # Always resize - image = F.resize(image, self.resize_size) - image = image.clamp(0.0, 1.0) - new_observation[key] = image.to(device) - - new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = new_observation - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "crop_params_dict": self.crop_params_dict, - "resize_size": self.resize_size, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -@dataclass -@ProcessorStepRegistry.register("time_limit_processor") -class TimeLimitProcessor: - """Track episode time and enforce time limits.""" - - max_episode_steps: int - current_step: int = 0 - - def __call__(self, transition: EnvTransition) -> EnvTransition: - truncated = transition.get(TransitionKey.TRUNCATED) - if truncated is None: - return transition - - self.current_step += 1 - if self.current_step >= self.max_episode_steps: - truncated = True - new_transition = transition.copy() - new_transition[TransitionKey.TRUNCATED] = truncated - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "max_episode_steps": self.max_episode_steps, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - self.current_step = 0 - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, 
PolicyFeature]: - return features - - -@dataclass -@ProcessorStepRegistry.register("gripper_penalty_processor") -class GripperPenaltyProcessor: - penalty: float = -0.01 - max_gripper_pos: float = 30.0 - - def __call__(self, transition: EnvTransition) -> EnvTransition: - """Calculate gripper penalty and add to complementary data.""" - action = transition.get(TransitionKey.ACTION) - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - - if complementary_data is None or action is None: - return transition - - current_gripper_pos = complementary_data.get("raw_joint_positions", None)[-1] - if current_gripper_pos is None: - return transition - - gripper_action = action[-1].item() - gripper_action_normalized = gripper_action / self.max_gripper_pos - - # Normalize gripper state and action - gripper_state_normalized = current_gripper_pos / self.max_gripper_pos - gripper_action_normalized = gripper_action - 1.0 - - # Calculate penalty boolean as in original - gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( - gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 - ) - - gripper_penalty = self.penalty * int(gripper_penalty_bool) - - # Add penalty information to complementary data - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - - # Create new complementary data with penalty info - new_complementary_data = dict(complementary_data) - new_complementary_data["discrete_penalty"] = gripper_penalty - - # Create new transition with updated complementary data - new_transition = transition.copy() - new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "penalty": self.penalty, - "max_gripper_pos": self.max_gripper_pos, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def 
reset(self) -> None: - """Reset the processor state.""" - self.last_gripper_state = None - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -@dataclass -@ProcessorStepRegistry.register("intervention_action_processor") -class InterventionActionProcessor: - """Handle action intervention based on signals in the transition. - - This processor checks for intervention signals in the transition's complementary data - and overrides agent actions when intervention is active. - """ - - use_gripper: bool = False - - def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition.get(TransitionKey.ACTION) - if action is None: - return transition - - # Get intervention signals from complementary data - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - teleop_action = complementary_data.get("teleop_action", {}) - is_intervention = complementary_data.get("is_intervention", False) - terminate_episode = complementary_data.get("terminate_episode", False) - success = complementary_data.get("success", False) - rerecord_episode = complementary_data.get("rerecord_episode", False) - - new_transition = transition.copy() - - # Override action if intervention is active - if is_intervention and teleop_action: - # Convert teleop_action dict to tensor format - action_list = [ - teleop_action.get("delta_x", 0.0), - teleop_action.get("delta_y", 0.0), - teleop_action.get("delta_z", 0.0), - ] - if self.use_gripper: - action_list.append(teleop_action.get("gripper", 1.0)) - - teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device) - new_transition[TransitionKey.ACTION] = teleop_action_tensor - - # Handle episode termination - new_transition[TransitionKey.DONE] = bool(terminate_episode) - new_transition[TransitionKey.REWARD] = float(success) - - # Update info with intervention metadata - info = new_transition.get(TransitionKey.INFO, {}) - 
info["is_intervention"] = is_intervention - info["rerecord_episode"] = rerecord_episode - info["next.success"] = success if terminate_episode else info.get("next.success", False) - new_transition[TransitionKey.INFO] = info - new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[ - TransitionKey.ACTION - ] - - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "use_gripper": self.use_gripper, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -@dataclass -@ProcessorStepRegistry.register("inverse_kinematics_processor") -class InverseKinematicsProcessor: - """Convert end-effector space actions to joint space using inverse kinematics. - - This processor transforms delta commands in end-effector space (delta_x, delta_y, delta_z) - to joint space commands using forward and inverse kinematics. It maintains the current - end-effector pose and joint positions to compute the transformations. 
- """ - - urdf_path: str - target_frame_name: str = "gripper_link" - end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0}) - end_effector_bounds: dict[str, list[float]] | None = None - max_gripper_pos: float = 30.0 - env: gym.Env = None # Environment reference to get current state - - # State tracking - current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False) - current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False) - kinematics: RobotKinematics | None = field(default=None, init=False, repr=False) - - def __post_init__(self): - """Initialize the kinematics module after dataclass initialization.""" - if self.urdf_path: - self.kinematics = RobotKinematics( - urdf_path=self.urdf_path, - target_frame_name=self.target_frame_name, - ) - - def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition.get(TransitionKey.ACTION) - if action is None: - return transition - - action_np = action.detach().cpu().numpy().squeeze() - - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - raw_joint_positions = complementary_data.get("raw_joint_positions") - current_gripper_pos = raw_joint_positions[-1] - if self.current_joint_pos is None: - self.current_joint_pos = raw_joint_positions - - # Initialize end-effector position if not available - if self.current_joint_pos is None: - return transition # Cannot proceed without joint positions - - # Calculate current end-effector position using forward kinematics - if self.current_ee_pos is None: - self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos) - - # Scale deltas by step sizes - delta_ee = np.array( - [ - action_np[0] * self.end_effector_step_sizes["x"], - action_np[1] * self.end_effector_step_sizes["y"], - action_np[2] * self.end_effector_step_sizes["z"], - ], - dtype=np.float32, - ) - - # Set desired end-effector position by adding delta - desired_ee_pos = 
np.eye(4) - desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation - - # Add delta to position and clip to bounds - desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee - if self.end_effector_bounds is not None: - desired_ee_pos[:3, 3] = np.clip( - desired_ee_pos[:3, 3], - self.end_effector_bounds["min"], - self.end_effector_bounds["max"], - ) - - # Compute inverse kinematics to get joint positions - target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos) - - # Update current state - self.current_ee_pos = desired_ee_pos.copy() - self.current_joint_pos = target_joint_values.copy() - - # Create new action with joint space commands - gripper_action = current_gripper_pos - if len(action_np) > 3: - # Handle gripper command separately - gripper_command = action_np[3] - - # Process gripper command (convert from [0,2] to delta) and discretize - gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos - gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos) - - # Combine joint positions and gripper - target_joint_values[-1] = gripper_action - - converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype) - - new_transition = transition.copy() - new_transition[TransitionKey.ACTION] = converted_action - return new_transition - - def get_config(self) -> dict[str, Any]: - return { - "urdf_path": self.urdf_path, - "target_frame_name": self.target_frame_name, - "end_effector_step_sizes": self.end_effector_step_sizes, - "end_effector_bounds": self.end_effector_bounds, - "max_gripper_pos": self.max_gripper_pos, - } - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - """Reset the processor state.""" - self.current_ee_pos = None - self.current_joint_pos = None - - def feature_contract(self, features: dict[str, 
PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]: """ Factory function to create a robot environment. @@ -821,7 +357,6 @@ def make_processors(env, cfg): end_effector_step_sizes=cfg.processor.end_effector_step_sizes, end_effector_bounds=cfg.processor.end_effector_bounds, max_gripper_pos=cfg.processor.max_gripper_pos, - env=env, ), ] action_processor = RobotProcessor(steps=action_pipeline_steps)