mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
change folder structure to reduce the size of gym_manip
This commit is contained in:
@@ -15,6 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .device_processor import DeviceProcessor
|
||||
from .hil_processor import (
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
TimeLimitProcessor,
|
||||
)
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .observation_processor import (
|
||||
ImageProcessor,
|
||||
@@ -36,15 +42,26 @@ from .pipeline import (
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
from .robot_processor import (
|
||||
InverseKinematicsProcessor,
|
||||
JointVelocityProcessor,
|
||||
MotorCurrentProcessor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"EnvTransition",
|
||||
"GripperPenaltyProcessor",
|
||||
"IdentityProcessor",
|
||||
"ImageCropResizeProcessor",
|
||||
"ImageProcessor",
|
||||
"InfoProcessor",
|
||||
"InterventionActionProcessor",
|
||||
"InverseKinematicsProcessor",
|
||||
"JointVelocityProcessor",
|
||||
"MotorCurrentProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"ObservationProcessor",
|
||||
@@ -54,6 +71,7 @@ __all__ = [
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"StateProcessor",
|
||||
"TimeLimitProcessor",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("image_crop_resize_processor")
class ImageCropResizeProcessor:
    """Crop and resize image observations in a transition.

    For every observation key containing ``"image"``, optionally crops the
    tensor using per-key crop parameters and then resizes it to
    ``resize_size``. Non-image keys pass through untouched.

    Attributes:
        crop_params_dict: Maps observation keys to ``(top, left, height, width)``
            rectangles forwarded to torchvision's functional ``crop``.
        resize_size: Target ``(height, width)`` of the output images, or None
            to skip resizing.
    """

    crop_params_dict: dict[str, tuple[int, int, int, int]]
    resize_size: tuple[int, int] = (128, 128)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # Nothing to do: no resize target and no crop parameters configured.
        if self.resize_size is None and not self.crop_params_dict:
            return transition

        new_observation = dict(observation)

        for key in observation:
            if "image" not in key:
                continue

            image = observation[key]
            device = image.device
            # Some torchvision ops are unsupported on MPS; run on CPU and move back.
            if device.type == "mps":
                image = image.cpu()

            # Crop if crop params are provided for this key.
            if key in self.crop_params_dict:
                image = F.crop(image, *self.crop_params_dict[key])

            # BUGFIX: only resize when a target size is configured; the original
            # called F.resize unconditionally, which crashes when resize_size is
            # None but crop params are provided.
            if self.resize_size is not None:
                image = F.resize(image, self.resize_size)

            # Keep pixel values in the normalized [0, 1] range after interpolation.
            image = image.clamp(0.0, 1.0)
            new_observation[key] = image.to(device)

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "crop_params_dict": self.crop_params_dict,
            "resize_size": self.resize_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Stateless step: nothing to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Stateless step: nothing to restore."""
        pass

    def reset(self) -> None:
        """Stateless step: nothing to reset."""
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): image shapes change after crop/resize, yet the contract
        # is returned unchanged — confirm downstream consumers do not rely on
        # the original shapes.
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
    """Count environment steps and raise the truncated flag at the limit.

    Attributes:
        max_episode_steps: Step budget after which the episode is truncated.
        current_step: Steps counted since the last ``reset()``.
    """

    max_episode_steps: int
    current_step: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        truncated = transition.get(TransitionKey.TRUNCATED)
        # Without a truncated field there is nothing to update.
        if truncated is None:
            return transition

        self.current_step += 1
        limit_reached = self.current_step >= self.max_episode_steps

        updated = transition.copy()
        updated[TransitionKey.TRUNCATED] = True if limit_reached else truncated
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {"max_episode_steps": self.max_episode_steps}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Restart the step counter for a new episode."""
        self.current_step = 0

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """This step does not alter the feature schema."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
    """Add a discrete penalty for wasteful gripper commands.

    Writes the penalty under the ``"discrete_penalty"`` key of the
    transition's complementary data.

    Attributes:
        penalty: Reward penalty (negative) applied when the gripper command
            fights the current gripper state.
        max_gripper_pos: Raw gripper position used to normalize the state.
    """

    penalty: float = -0.01
    max_gripper_pos: float = 30.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Calculate gripper penalty and add it to complementary data."""
        action = transition.get(TransitionKey.ACTION)
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)

        if complementary_data is None or action is None:
            return transition

        # BUGFIX: guard against missing joint positions before indexing; the
        # original indexed the .get() result directly and crashed on None.
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        if current_gripper_pos is None:
            return transition

        gripper_action = action[-1].item()

        # Normalize the gripper state; the action is shifted from [0, 2] to [-1, 1].
        # NOTE(review): an earlier assignment dividing the action by
        # max_gripper_pos was dead code (immediately overwritten) and has been
        # removed; effective behavior is unchanged.
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
        gripper_action_normalized = gripper_action - 1.0

        # Penalize opening an almost-closed gripper or closing an open one.
        gripper_penalty_bool = (
            gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5
        ) or (gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5)
        gripper_penalty = self.penalty * int(gripper_penalty_bool)

        # Copy the dict so the incoming transition is not mutated in place.
        new_complementary_data = dict(complementary_data)
        new_complementary_data["discrete_penalty"] = gripper_penalty

        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        # NOTE(review): this attribute is never read inside this class; kept
        # for backward compatibility with external code that may inspect it.
        self.last_gripper_state = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """This step does not alter the feature schema."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
    """Handle action intervention based on signals in the transition.

    This processor checks for intervention signals in the transition's
    complementary data and overrides agent actions when intervention is
    active. It also propagates termination/success signals into the
    done/reward fields and mirrors intervention metadata into the info dict.

    Attributes:
        use_gripper: When True, a fourth gripper component is appended to the
            teleoperator action vector.
    """

    use_gripper: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        # Intervention signals live in the complementary data.
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        teleop_action = complementary_data.get("teleop_action", {})
        is_intervention = complementary_data.get("is_intervention", False)
        terminate_episode = complementary_data.get("terminate_episode", False)
        success = complementary_data.get("success", False)
        rerecord_episode = complementary_data.get("rerecord_episode", False)

        new_transition = transition.copy()

        # Replace the agent action with the teleoperator action while a human
        # is intervening.
        if is_intervention and teleop_action:
            action_list = [
                teleop_action.get("delta_x", 0.0),
                teleop_action.get("delta_y", 0.0),
                teleop_action.get("delta_z", 0.0),
            ]
            if self.use_gripper:
                action_list.append(teleop_action.get("gripper", 1.0))

            teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
            new_transition[TransitionKey.ACTION] = teleop_action_tensor

        # Propagate episode termination and success reward.
        new_transition[TransitionKey.DONE] = bool(terminate_episode)
        new_transition[TransitionKey.REWARD] = float(success)

        # BUGFIX: copy the info and complementary-data dicts before mutating
        # them — transition.copy() is shallow, so the original code mutated the
        # caller's dicts in place, and indexing COMPLEMENTARY_DATA directly
        # raised KeyError when the key was absent.
        info = dict(new_transition.get(TransitionKey.INFO, {}))
        info["is_intervention"] = is_intervention
        info["rerecord_episode"] = rerecord_episode
        info["next.success"] = success if terminate_episode else info.get("next.success", False)
        new_transition[TransitionKey.INFO] = info

        new_complementary_data = dict(complementary_data)
        new_complementary_data["teleop_action"] = new_transition[TransitionKey.ACTION]
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {"use_gripper": self.use_gripper}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Stateless step: nothing to reset."""
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """This step does not alter the feature schema."""
        return features
|
||||
@@ -0,0 +1,245 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
    """Append finite-difference joint velocities to the observation state.

    Velocities are estimated from the change in ``"observation.state"``
    between consecutive calls, divided by the fixed timestep ``dt``.

    Attributes:
        joint_velocity_limits: Configured velocity limit (serialized in the
            config; not applied as a clamp by this step).
        dt: Timestep in seconds between consecutive observations.
        last_joint_positions: Positions seen on the previous call, or None
            before the first call / after ``reset()``.
    """

    joint_velocity_limits: float = 100.0
    dt: float = 1.0 / 10

    last_joint_positions: torch.Tensor | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        positions = observation.get("observation.state")
        if positions is None:
            return transition

        # First call (or just after reset): velocities come out as zeros.
        if self.last_joint_positions is None:
            self.last_joint_positions = positions.clone()

        velocities = (positions - self.last_joint_positions) / self.dt
        self.last_joint_positions = positions.clone()

        # Concatenate positions and velocities into one state vector.
        augmented = dict(observation)
        augmented["observation.state"] = torch.cat([positions, velocities], dim=-1)

        result = transition.copy()
        result[TransitionKey.OBSERVATION] = augmented
        return result

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "joint_velocity_limits": self.joint_velocity_limits,
            "dt": self.dt,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Forget the previous joint positions for a new episode."""
        self.last_joint_positions = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): the state dimension doubles after this step, yet the
        # contract is returned unchanged — confirm downstream consumers.
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor:
    """Append motor current readings to the observation state.

    Reads ``"Present_Current"`` for every motor from the robot bus of the
    wrapped environment and concatenates the values to
    ``"observation.state"``.

    Attributes:
        env: Gym environment whose unwrapped robot exposes the motor bus.
    """

    env: gym.Env = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # BUGFIX/perf: verify a state vector exists before touching the
        # hardware; the original performed a bus read and then discarded the
        # result when "observation.state" was missing.
        current_state = observation.get("observation.state")
        if current_state is None:
            return transition

        # Read the present current of every motor from the robot bus.
        present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current")
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors],
            dtype=torch.float32,
        ).unsqueeze(0)

        extended_state = torch.cat([current_state, motor_currents], dim=-1)

        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Nothing serializable: the env handle cannot be rebuilt from config."""
        return {}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Stateless step: nothing to reset."""
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): the state dimension grows by the motor count, yet the
        # contract is returned unchanged — confirm downstream consumers.
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("inverse_kinematics_processor")
class InverseKinematicsProcessor:
    """Convert end-effector space actions to joint space using inverse kinematics.

    This processor transforms delta commands in end-effector space
    (delta_x, delta_y, delta_z) to joint space commands using forward and
    inverse kinematics. It maintains the current end-effector pose and joint
    positions to compute the transformations.

    Attributes:
        urdf_path: Path to the robot URDF used to build the kinematics model.
        target_frame_name: Frame whose pose is controlled (end effector).
        end_effector_step_sizes: Per-axis scale applied to the action deltas.
        end_effector_bounds: Optional ``{"min": ..., "max": ...}`` position bounds.
        max_gripper_pos: Maximum raw gripper position; also the magnitude of
            one discretized gripper step.
    """

    urdf_path: str
    target_frame_name: str = "gripper_link"
    end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
    end_effector_bounds: dict[str, list[float]] | None = None
    max_gripper_pos: float = 30.0

    # State tracking (rebuilt lazily; cleared by reset()).
    current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the kinematics module after dataclass initialization."""
        if self.urdf_path:
            self.kinematics = RobotKinematics(
                urdf_path=self.urdf_path,
                target_frame_name=self.target_frame_name,
            )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        action_np = action.detach().cpu().numpy().squeeze()

        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        # BUGFIX: the original indexed raw_joint_positions without a None
        # check and crashed when the key was absent. The guard also makes the
        # original's later (unreachable) current_joint_pos None-check redundant.
        if raw_joint_positions is None:
            return transition  # Cannot proceed without joint positions.
        current_gripper_pos = raw_joint_positions[-1]

        if self.current_joint_pos is None:
            self.current_joint_pos = raw_joint_positions

        # Compute the current end-effector pose lazily via forward kinematics.
        if self.current_ee_pos is None:
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)

        # Scale the commanded deltas by the per-axis step sizes.
        delta_ee = np.array(
            [
                action_np[0] * self.end_effector_step_sizes["x"],
                action_np[1] * self.end_effector_step_sizes["y"],
                action_np[2] * self.end_effector_step_sizes["z"],
            ],
            dtype=np.float32,
        )

        # Desired pose: keep the current orientation, translate by the delta.
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
        if self.end_effector_bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(
                desired_ee_pos[:3, 3],
                self.end_effector_bounds["min"],
                self.end_effector_bounds["max"],
            )

        # Solve IK for the joint-space target of the desired pose.
        target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)

        # Track the commanded pose/joints for the next call.
        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values.copy()

        # Gripper: map the command from [0, 2] to a discretized delta and clip.
        gripper_action = current_gripper_pos
        if len(action_np) > 3:
            gripper_command = action_np[3]
            gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
            gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)

        # The last joint slot carries the gripper position.
        target_joint_values[-1] = gripper_action

        converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)

        new_transition = transition.copy()
        new_transition[TransitionKey.ACTION] = converted_action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "urdf_path": self.urdf_path,
            "target_frame_name": self.target_frame_name,
            "end_effector_step_sizes": self.end_effector_step_sizes,
            "end_effector_bounds": self.end_effector_bounds,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        self.current_ee_pos = None
        self.current_joint_pos = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """This step does not alter the feature schema."""
        return features
|
||||
@@ -16,27 +16,34 @@
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
ImageProcessor,
|
||||
RobotProcessor,
|
||||
StateProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.hil_processor import (
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
TimeLimitProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.processor.robot_processor import (
|
||||
InverseKinematicsProcessor,
|
||||
JointVelocityProcessor,
|
||||
MotorCurrentProcessor,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
@@ -279,477 +286,6 @@ class RobotEnv(gym.Env):
|
||||
self.robot.disconnect()
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
    """Add joint velocity information to observations.

    Computes joint velocities by tracking changes in joint positions over
    time, dividing the position delta by the fixed timestep ``dt``.

    Attributes:
        joint_velocity_limits: Configured velocity limit (serialized in the
            config; not applied as a clamp by this step).
        dt: Timestep in seconds between consecutive observations.
        last_joint_positions: Positions from the previous call, or None
            before the first call / after ``reset()``.
    """

    joint_velocity_limits: float = 100.0
    dt: float = 1.0 / 10

    last_joint_positions: torch.Tensor | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # Get current joint positions (assuming they're in observation.state).
        current_positions = observation.get("observation.state")
        if current_positions is None:
            return transition

        # First call (or just after reset): velocities come out as zeros.
        if self.last_joint_positions is None:
            self.last_joint_positions = current_positions.clone()

        joint_velocities = (current_positions - self.last_joint_positions) / self.dt
        self.last_joint_positions = current_positions.clone()

        # Extend the state vector with the estimated velocities.
        extended_state = torch.cat([current_positions, joint_velocities], dim=-1)

        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "joint_velocity_limits": self.joint_velocity_limits,
            # BUGFIX: the original returned "fps": self.fps, but no `fps`
            # attribute exists on this dataclass (AttributeError); the actual
            # field is `dt`.
            "dt": self.dt,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Forget the previous joint positions for a new episode."""
        self.last_joint_positions = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): the state dimension doubles after this step, yet the
        # contract is returned unchanged — confirm downstream consumers.
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor:
    """Add motor current information to observations.

    Reads ``"Present_Current"`` for every motor from the robot bus of the
    wrapped environment and concatenates the readings to
    ``"observation.state"``.

    Attributes:
        env: Gym environment whose unwrapped robot exposes the motor bus.
    """

    env: gym.Env = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # BUGFIX/perf: verify a state vector exists before touching the
        # hardware; the original performed a bus read and then discarded the
        # result when "observation.state" was missing.
        current_state = observation.get("observation.state")
        if current_state is None:
            return transition

        # Read the present current of every motor from the robot bus.
        present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current")
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors],
            dtype=torch.float32,
        ).unsqueeze(0)

        extended_state = torch.cat([current_state, motor_currents], dim=-1)

        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Nothing serializable: the env handle cannot be rebuilt from config."""
        return {}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Stateless step: nothing to reset."""
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): the state dimension grows by the motor count, yet the
        # contract is returned unchanged — confirm downstream consumers.
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("image_crop_resize_processor")
class ImageCropResizeProcessor:
    """Crop and resize image observations.

    For every observation key containing ``"image"``, optionally crops the
    tensor using per-key crop parameters and then resizes it to
    ``resize_size``. Non-image keys pass through untouched.

    Attributes:
        crop_params_dict: Maps observation keys to ``(top, left, height, width)``
            rectangles forwarded to torchvision's functional ``crop``.
        resize_size: Target ``(height, width)`` of the output images, or None
            to skip resizing.
    """

    crop_params_dict: dict[str, tuple[int, int, int, int]]
    resize_size: tuple[int, int] = (128, 128)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition

        # Nothing to do: no resize target and no crop parameters configured.
        if self.resize_size is None and not self.crop_params_dict:
            return transition

        new_observation = dict(observation)

        for key in observation:
            if "image" not in key:
                continue

            image = observation[key]
            device = image.device
            # Some torchvision ops are unsupported on MPS; run on CPU and move back.
            if device.type == "mps":
                image = image.cpu()

            # Crop if crop params are provided for this key.
            if key in self.crop_params_dict:
                image = F.crop(image, *self.crop_params_dict[key])

            # BUGFIX: only resize when a target size is configured; the original
            # called F.resize unconditionally, which crashes when resize_size is
            # None but crop params are provided.
            if self.resize_size is not None:
                image = F.resize(image, self.resize_size)

            # Keep pixel values in the normalized [0, 1] range after interpolation.
            image = image.clamp(0.0, 1.0)
            new_observation[key] = image.to(device)

        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "crop_params_dict": self.crop_params_dict,
            "resize_size": self.resize_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Stateless step: nothing to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Stateless step: nothing to restore."""
        pass

    def reset(self) -> None:
        """Stateless step: nothing to reset."""
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # NOTE(review): image shapes change after crop/resize, yet the contract
        # is returned unchanged — confirm downstream consumers.
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
    """Track episode time and enforce time limits.

    Attributes:
        max_episode_steps: Step budget after which the episode is truncated.
        current_step: Steps counted since the last ``reset()``.
    """

    max_episode_steps: int
    current_step: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        truncated = transition.get(TransitionKey.TRUNCATED)
        # Without a truncated field there is nothing to update.
        if truncated is None:
            return transition

        self.current_step += 1
        limit_reached = self.current_step >= self.max_episode_steps

        updated = transition.copy()
        updated[TransitionKey.TRUNCATED] = True if limit_reached else truncated
        return updated

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {"max_episode_steps": self.max_episode_steps}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Restart the step counter for a new episode."""
        self.current_step = 0

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """This step does not alter the feature schema."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
    """Add a discrete penalty for wasteful gripper commands.

    Writes the penalty under the ``"discrete_penalty"`` key of the
    transition's complementary data.

    Attributes:
        penalty: Reward penalty (negative) applied when the gripper command
            fights the current gripper state.
        max_gripper_pos: Raw gripper position used to normalize the state.
    """

    penalty: float = -0.01
    max_gripper_pos: float = 30.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Calculate gripper penalty and add it to complementary data."""
        action = transition.get(TransitionKey.ACTION)
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)

        if complementary_data is None or action is None:
            return transition

        # BUGFIX: guard against missing joint positions before indexing; the
        # original indexed the .get() result directly and crashed on None.
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        if current_gripper_pos is None:
            return transition

        gripper_action = action[-1].item()

        # Normalize the gripper state; the action is shifted from [0, 2] to [-1, 1].
        # NOTE(review): an earlier assignment dividing the action by
        # max_gripper_pos was dead code (immediately overwritten) and has been
        # removed; effective behavior is unchanged.
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
        gripper_action_normalized = gripper_action - 1.0

        # Penalize opening an almost-closed gripper or closing an open one.
        gripper_penalty_bool = (
            gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5
        ) or (gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5)
        gripper_penalty = self.penalty * int(gripper_penalty_bool)

        # Copy the dict so the incoming transition is not mutated in place.
        new_complementary_data = dict(complementary_data)
        new_complementary_data["discrete_penalty"] = gripper_penalty

        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step."""
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """No tensor state to serialize."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No tensor state to restore."""
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        # NOTE(review): this attribute is never read inside this class; kept
        # for backward compatibility with external code that may inspect it.
        self.last_gripper_state = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """This step does not alter the feature schema."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
    """Handle action intervention based on signals in the transition.

    This processor checks for intervention signals in the transition's
    complementary data and overrides agent actions when intervention is
    active. It also maps the termination/success flags onto the DONE and
    REWARD fields and records intervention metadata in INFO.

    Attributes:
        use_gripper: If True, the teleop action carries a 4th "gripper"
            component in addition to the (delta_x, delta_y, delta_z) deltas.
    """

    use_gripper: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a new transition with the teleop action applied when intervening."""
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        # Get intervention signals from complementary data
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        teleop_action = complementary_data.get("teleop_action", {})
        is_intervention = complementary_data.get("is_intervention", False)
        terminate_episode = complementary_data.get("terminate_episode", False)
        success = complementary_data.get("success", False)
        rerecord_episode = complementary_data.get("rerecord_episode", False)

        new_transition = transition.copy()

        # Override action if intervention is active
        if is_intervention and teleop_action:
            # Convert the teleop_action dict to tensor format (delta command)
            action_list = [
                teleop_action.get("delta_x", 0.0),
                teleop_action.get("delta_y", 0.0),
                teleop_action.get("delta_z", 0.0),
            ]
            if self.use_gripper:
                action_list.append(teleop_action.get("gripper", 1.0))

            teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
            new_transition[TransitionKey.ACTION] = teleop_action_tensor

        # Handle episode termination and the sparse success reward
        new_transition[TransitionKey.DONE] = bool(terminate_episode)
        new_transition[TransitionKey.REWARD] = float(success)

        # Update info with intervention metadata. Copy the dict first:
        # transition.copy() is shallow, so writing into the original info
        # dict would mutate the caller's transition in place.
        info = dict(new_transition.get(TransitionKey.INFO, {}))
        info["is_intervention"] = is_intervention
        info["rerecord_episode"] = rerecord_episode
        info["next.success"] = success if terminate_episode else info.get("next.success", False)
        new_transition[TransitionKey.INFO] = info

        # Record the action actually executed under "teleop_action". Build a
        # fresh complementary-data dict (same shallow-copy caveat as above);
        # assigning the copy back also avoids a KeyError when the transition
        # had no COMPLEMENTARY_DATA entry at all.
        new_complementary_data = dict(complementary_data)
        new_complementary_data["teleop_action"] = new_transition[TransitionKey.ACTION]
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this processor step."""
        return {
            "use_gripper": self.use_gripper,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping: this processor keeps no tensor state."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No-op: this processor keeps no tensor state to restore."""
        pass

    def reset(self) -> None:
        """No per-episode state to reset."""
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Identity contract: this step leaves the feature specification untouched."""
        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register("inverse_kinematics_processor")
class InverseKinematicsProcessor:
    """Convert end-effector space actions to joint space using inverse kinematics.

    This processor transforms delta commands in end-effector space
    (delta_x, delta_y, delta_z) to joint space commands using forward and
    inverse kinematics. It maintains the current end-effector pose and joint
    positions to compute the transformations.

    Attributes:
        urdf_path: Path to the robot URDF used to build the kinematics model.
        target_frame_name: Name of the frame the deltas are applied to.
        end_effector_step_sizes: Per-axis ("x", "y", "z") scaling of the deltas.
        end_effector_bounds: Optional {"min": ..., "max": ...} position clamp.
        max_gripper_pos: Maximum raw gripper position; also scales gripper deltas.
        env: Environment reference to get current state.
    """

    urdf_path: str
    target_frame_name: str = "gripper_link"
    end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
    end_effector_bounds: dict[str, list[float]] | None = None
    max_gripper_pos: float = 30.0
    env: gym.Env = None  # Environment reference to get current state

    # State tracking (runtime only; excluded from init/repr)
    current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the kinematics module after dataclass initialization."""
        if self.urdf_path:
            self.kinematics = RobotKinematics(
                urdf_path=self.urdf_path,
                target_frame_name=self.target_frame_name,
            )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Map a delta end-effector action to absolute joint-space targets."""
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition

        action_np = action.detach().cpu().numpy().squeeze()

        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        if raw_joint_positions is None:
            # Cannot proceed without the robot's current joint state
            # (the original code would raise a TypeError on indexing here).
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        if self.current_joint_pos is None:
            self.current_joint_pos = raw_joint_positions

        # Calculate the current end-effector pose via forward kinematics
        # the first time we see a joint state after a reset.
        if self.current_ee_pos is None:
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)

        # Scale deltas by the per-axis step sizes
        delta_ee = np.array(
            [
                action_np[0] * self.end_effector_step_sizes["x"],
                action_np[1] * self.end_effector_step_sizes["y"],
                action_np[2] * self.end_effector_step_sizes["z"],
            ],
            dtype=np.float32,
        )

        # Desired pose: keep the current orientation, translate by the delta
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]  # Keep orientation

        # Add delta to position and clip to bounds when configured
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
        if self.end_effector_bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(
                desired_ee_pos[:3, 3],
                self.end_effector_bounds["min"],
                self.end_effector_bounds["max"],
            )

        # Compute inverse kinematics to get joint positions
        target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)

        # Update tracked state for the next step
        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values.copy()

        # Gripper: default to holding the current position
        gripper_action = current_gripper_pos
        if len(action_np) > 3:
            # Handle the gripper command separately: convert from [0, 2]
            # to a discrete delta and clamp to the valid range.
            gripper_command = action_np[3]
            gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
            gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)

        # Last joint slot carries the gripper target
        target_joint_values[-1] = gripper_action

        converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)

        new_transition = transition.copy()
        new_transition[TransitionKey.ACTION] = converted_action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this processor step."""
        return {
            "urdf_path": self.urdf_path,
            "target_frame_name": self.target_frame_name,
            "end_effector_step_sizes": self.end_effector_step_sizes,
            "end_effector_bounds": self.end_effector_bounds,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return an empty mapping: this processor keeps no tensor state."""
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """No-op: this processor keeps no tensor state to restore."""
        pass

    def reset(self) -> None:
        """Reset the tracked kinematic state between episodes."""
        self.current_ee_pos = None
        self.current_joint_pos = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Identity contract: this step leaves the feature specification untouched."""
        return features
|
||||
|
||||
|
||||
def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
|
||||
"""
|
||||
Factory function to create a robot environment.
|
||||
@@ -821,7 +357,6 @@ def make_processors(env, cfg):
|
||||
end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
|
||||
end_effector_bounds=cfg.processor.end_effector_bounds,
|
||||
max_gripper_pos=cfg.processor.max_gripper_pos,
|
||||
env=env,
|
||||
),
|
||||
]
|
||||
action_processor = RobotProcessor(steps=action_pipeline_steps)
|
||||
|
||||
Reference in New Issue
Block a user