change folder structure to reduce the size of gym_manip

This commit is contained in:
Michel Aractingi
2025-08-04 16:41:42 +02:00
parent f49280e89b
commit b292dbbc55
4 changed files with 518 additions and 477 deletions
+18
View File
@@ -15,6 +15,12 @@
# limitations under the License.
from .device_processor import DeviceProcessor
from .hil_processor import (
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
TimeLimitProcessor,
)
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import (
ImageProcessor,
@@ -36,15 +42,26 @@ from .pipeline import (
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
from .robot_processor import (
InverseKinematicsProcessor,
JointVelocityProcessor,
MotorCurrentProcessor,
)
__all__ = [
"ActionProcessor",
"DeviceProcessor",
"DoneProcessor",
"EnvTransition",
"GripperPenaltyProcessor",
"IdentityProcessor",
"ImageCropResizeProcessor",
"ImageProcessor",
"InfoProcessor",
"InterventionActionProcessor",
"InverseKinematicsProcessor",
"JointVelocityProcessor",
"MotorCurrentProcessor",
"NormalizerProcessor",
"UnnormalizerProcessor",
"ObservationProcessor",
@@ -54,6 +71,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TimeLimitProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",
+243
View File
@@ -0,0 +1,243 @@
from dataclasses import dataclass
from typing import Any
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register("image_crop_resize_processor")
class ImageCropResizeProcessor:
    """Crop and/or resize image observations.

    For every observation key containing ``"image"``, optionally crops with the
    per-key parameters in ``crop_params_dict`` and then resizes to
    ``resize_size``. Images are round-tripped through the CPU on MPS devices
    because torchvision's crop/resize kernels are not supported there.
    """

    # Mapping of observation key -> (top, left, height, width) crop box.
    crop_params_dict: dict[str, tuple[int, int, int, int]]
    # Target (height, width); None disables resizing.
    resize_size: tuple[int, int] | None = (128, 128)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition
        # Nothing to do when neither cropping nor resizing is configured.
        if self.resize_size is None and not self.crop_params_dict:
            return transition
        new_observation = dict(observation)
        for key in observation:
            if "image" not in key:
                continue
            image = observation[key]
            device = image.device
            # torchvision transforms do not support MPS tensors.
            if device.type == "mps":
                image = image.cpu()
            if key in self.crop_params_dict:
                image = F.crop(image, *self.crop_params_dict[key])
            # Fix: only resize when a target size is configured. Previously
            # F.resize(image, None) was called unconditionally and raised when
            # crop params were set but resize_size was None.
            if self.resize_size is not None:
                image = F.resize(image, self.resize_size)
            image = image.clamp(0.0, 1.0)
            new_observation[key] = image.to(device)
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "crop_params_dict": self.crop_params_dict,
            "resize_size": self.resize_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
    """Count steps within an episode and force truncation at the limit."""

    # Maximum number of steps before the episode is truncated.
    max_episode_steps: int
    # Steps taken since the last reset().
    current_step: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        truncated = transition.get(TransitionKey.TRUNCATED)
        # Without a truncation flag in the transition there is nothing to track.
        if truncated is None:
            return transition
        self.current_step += 1
        reached_limit = self.current_step >= self.max_episode_steps
        updated = transition.copy()
        updated[TransitionKey.TRUNCATED] = True if reached_limit else truncated
        return updated

    def get_config(self) -> dict[str, Any]:
        return {"max_episode_steps": self.max_episode_steps}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        # Restart the step counter at episode boundaries.
        self.current_step = 0

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
    """Add a discrete penalty to complementary data for wasteful gripper use."""

    # Penalty value applied when the gripper command fights its current state.
    penalty: float = -0.01
    # Raw gripper position corresponding to fully open; used for normalization.
    max_gripper_pos: float = 30.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Calculate gripper penalty and add it to complementary data."""
        action = transition.get(TransitionKey.ACTION)
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None or action is None:
            return transition
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        # Fix: guard before indexing. Previously the code did
        # `.get("raw_joint_positions", None)[-1]`, which raises TypeError when
        # the key is missing, making the subsequent None check useless.
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        # The gripper command is assumed to be the last action element.
        gripper_action = action[-1].item()
        # Normalize the state to [0, 1]; shift the action by -1 (command is
        # presumably in [0, 2] — see the IK processor's gripper handling).
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
        gripper_action_normalized = gripper_action - 1.0
        # Penalize opening an almost-closed gripper or closing an open one.
        gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
        )
        new_complementary_data = dict(complementary_data)
        new_complementary_data["discrete_penalty"] = self.penalty * int(gripper_penalty_bool)
        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Reset the processor state (stateless; nothing to do)."""
        # Removed the dead `self.last_gripper_state = None`: no code reads it.
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
    """Handle action intervention based on signals in the transition.

    This processor checks for intervention signals in the transition's
    complementary data and overrides agent actions when intervention is active.
    It also propagates termination/success flags into DONE/REWARD/INFO.
    """

    # When True, a 4th gripper element is appended to the teleop action.
    use_gripper: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        # Intervention signals live in complementary data.
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        teleop_action = complementary_data.get("teleop_action", {})
        is_intervention = complementary_data.get("is_intervention", False)
        terminate_episode = complementary_data.get("terminate_episode", False)
        success = complementary_data.get("success", False)
        rerecord_episode = complementary_data.get("rerecord_episode", False)

        new_transition = transition.copy()
        # Override the agent action while a human intervention is active.
        if is_intervention and teleop_action:
            action_list = [
                teleop_action.get("delta_x", 0.0),
                teleop_action.get("delta_y", 0.0),
                teleop_action.get("delta_z", 0.0),
            ]
            if self.use_gripper:
                action_list.append(teleop_action.get("gripper", 1.0))
            new_transition[TransitionKey.ACTION] = torch.tensor(
                action_list, dtype=action.dtype, device=action.device
            )

        # Episode termination and sparse success reward.
        new_transition[TransitionKey.DONE] = bool(terminate_episode)
        new_transition[TransitionKey.REWARD] = float(success)

        # Fix: copy the nested dicts before mutating. transition.copy() is
        # shallow, so writing into the inherited info/complementary_data dicts
        # also mutated the caller's transition in place.
        info = dict(new_transition.get(TransitionKey.INFO, {}))
        info["is_intervention"] = is_intervention
        info["rerecord_episode"] = rerecord_episode
        info["next.success"] = success if terminate_episode else info.get("next.success", False)
        new_transition[TransitionKey.INFO] = info

        new_complementary_data = dict(complementary_data)
        # Record the action actually applied (teleop or agent) for downstream use.
        new_complementary_data["teleop_action"] = new_transition[TransitionKey.ACTION]
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {"use_gripper": self.use_gripper}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
+245
View File
@@ -0,0 +1,245 @@
from dataclasses import dataclass, field
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
    """Append finite-difference joint velocities to the observation state.

    Velocities are estimated as (current - previous) / dt from successive
    joint-position observations; the first call after a reset yields zeros.
    """

    # Declared velocity limit; stored in the config but not enforced here.
    joint_velocity_limits: float = 100.0
    # Time step between successive observations, in seconds.
    dt: float = 1.0 / 10
    # Positions seen on the previous call; None until the first observation.
    last_joint_positions: torch.Tensor | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition
        positions = observation.get("observation.state")
        if positions is None:
            return transition
        # First call after reset: compare against the current positions,
        # producing zero velocities.
        previous = self.last_joint_positions if self.last_joint_positions is not None else positions.clone()
        velocities = (positions - previous) / self.dt
        self.last_joint_positions = positions.clone()
        new_observation = dict(observation)
        new_observation["observation.state"] = torch.cat([positions, velocities], dim=-1)
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "joint_velocity_limits": self.joint_velocity_limits,
            "dt": self.dt,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        # Forget the previous positions so the next episode starts fresh.
        self.last_joint_positions = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor:
    """Append motor current readings to the state vector in observations."""

    # Environment whose unwrapped robot bus is polled for present currents.
    env: gym.Env = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition
        current_state = observation.get("observation.state")
        # Fix: check for the state vector BEFORE touching the hardware.
        # Previously the bus was sync_read even when the result was discarded.
        if current_state is None:
            return transition
        bus = self.env.unwrapped.robot.bus
        present_current_dict = bus.sync_read("Present_Current")
        # Keep motor ordering consistent with the bus definition.
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in bus.motors],
            dtype=torch.float32,
        ).unsqueeze(0)
        new_observation = dict(observation)
        new_observation["observation.state"] = torch.cat([current_state, motor_currents], dim=-1)
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("inverse_kinematics_processor")
class InverseKinematicsProcessor:
    """Convert end-effector space actions to joint space using inverse kinematics.

    This processor transforms delta commands in end-effector space
    (delta_x, delta_y, delta_z) to joint space commands using forward and
    inverse kinematics. It maintains the current end-effector pose and joint
    positions to compute the transformations. An optional 4th action element
    drives the gripper.
    """

    urdf_path: str
    target_frame_name: str = "gripper_link"
    # Per-axis scaling applied to the raw deltas.
    end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
    # Optional {"min": [...], "max": [...]} position bounds.
    end_effector_bounds: dict[str, list[float]] | None = None
    max_gripper_pos: float = 30.0

    # State tracking (excluded from init/repr; reset() clears it).
    current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the kinematics module after dataclass initialization."""
        if self.urdf_path:
            self.kinematics = RobotKinematics(
                urdf_path=self.urdf_path,
                target_frame_name=self.target_frame_name,
            )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        action_np = action.detach().cpu().numpy().squeeze()
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        # Fix: bail out before indexing. Previously raw_joint_positions[-1]
        # raised TypeError when the key was missing, which made the later
        # `if self.current_joint_pos is None: return` unreachable dead code.
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        if self.current_joint_pos is None:
            self.current_joint_pos = raw_joint_positions
        # Seed the end-effector pose from forward kinematics on first use.
        if self.current_ee_pos is None:
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
        # Scale deltas by per-axis step sizes.
        delta_ee = np.array(
            [
                action_np[0] * self.end_effector_step_sizes["x"],
                action_np[1] * self.end_effector_step_sizes["y"],
                action_np[2] * self.end_effector_step_sizes["z"],
            ],
            dtype=np.float32,
        )
        # Target pose: keep the current orientation, translate by the delta.
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
        if self.end_effector_bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(
                desired_ee_pos[:3, 3],
                self.end_effector_bounds["min"],
                self.end_effector_bounds["max"],
            )
        # Solve IK for joint targets, then update the tracked state.
        target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)
        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values.copy()
        gripper_action = current_gripper_pos
        if len(action_np) > 3:
            # Map the gripper command from [0, 2] to a discrete +/- step and clip.
            gripper_command = action_np[3]
            gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
            gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)
        # The last joint slot carries the gripper command.
        target_joint_values[-1] = gripper_action
        converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)
        new_transition = transition.copy()
        new_transition[TransitionKey.ACTION] = converted_action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "urdf_path": self.urdf_path,
            "target_frame_name": self.target_frame_name,
            "end_effector_step_sizes": self.end_effector_step_sizes,
            "end_effector_bounds": self.end_effector_bounds,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        self.current_ee_pos = None
        self.current_joint_pos = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
+12 -477
View File
@@ -16,27 +16,34 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.types import PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import EnvConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
DeviceProcessor,
ImageProcessor,
RobotProcessor,
StateProcessor,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.hil_processor import (
GripperPenaltyProcessor,
ImageCropResizeProcessor,
InterventionActionProcessor,
TimeLimitProcessor,
)
from lerobot.processor.pipeline import TransitionKey
from lerobot.processor.robot_processor import (
InverseKinematicsProcessor,
JointVelocityProcessor,
MotorCurrentProcessor,
)
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
@@ -279,477 +286,6 @@ class RobotEnv(gym.Env):
self.robot.disconnect()
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessor:
    """Add joint velocity information to observations.

    Computes joint velocities by tracking changes in joint positions over time.
    """

    joint_velocity_limits: float = 100.0
    # Time step between successive observations, in seconds.
    dt: float = 1.0 / 10
    last_joint_positions: torch.Tensor | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition
        # Get current joint positions (assuming they're in observation.state)
        current_positions = observation.get("observation.state")
        if current_positions is None:
            return transition
        # Initialize last joint positions if not already set (first call
        # after reset yields zero velocities).
        if self.last_joint_positions is None:
            self.last_joint_positions = current_positions.clone()
        joint_velocities = (current_positions - self.last_joint_positions) / self.dt
        self.last_joint_positions = current_positions.clone()
        # Extend the state vector with the velocities.
        extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        # Fix: the original returned {"fps": self.fps}, but no `fps` attribute
        # exists on this class, so get_config always raised AttributeError.
        # Serialize the actual `dt` field instead.
        return {
            "joint_velocity_limits": self.joint_velocity_limits,
            "dt": self.dt,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        self.last_joint_positions = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessor:
    """Add motor current information to observations."""

    # Environment whose unwrapped robot bus is polled for present currents.
    env: gym.Env = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition
        current_state = observation.get("observation.state")
        # Fix: check for the state vector BEFORE the hardware read. Previously
        # the bus was sync_read even when the result was then discarded.
        if current_state is None:
            return transition
        bus = self.env.unwrapped.robot.bus
        present_current_dict = bus.sync_read("Present_Current")
        motor_currents = torch.tensor(
            [present_current_dict[name] for name in bus.motors],
            dtype=torch.float32,
        ).unsqueeze(0)
        extended_state = torch.cat([current_state, motor_currents], dim=-1)
        new_observation = dict(observation)
        new_observation["observation.state"] = extended_state
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("image_crop_resize_processor")
class ImageCropResizeProcessor:
    """Crop and resize image observations.

    Crops each ``"image"`` observation with the per-key parameters in
    ``crop_params_dict`` (when present) and resizes to ``resize_size``.
    """

    # Mapping of observation key -> (top, left, height, width) crop box.
    crop_params_dict: dict[str, tuple[int, int, int, int]]
    # Target (height, width); None disables resizing.
    resize_size: tuple[int, int] | None = (128, 128)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is None:
            return transition
        if self.resize_size is None and not self.crop_params_dict:
            return transition
        new_observation = dict(observation)
        # Process all image keys in the observation.
        for key in observation:
            if "image" not in key:
                continue
            image = observation[key]
            device = image.device
            # torchvision transforms do not support MPS tensors.
            if device.type == "mps":
                image = image.cpu()
            if key in self.crop_params_dict:
                image = F.crop(image, *self.crop_params_dict[key])
            # Fix: only resize when a target size is configured. Previously
            # F.resize(image, None) raised when crop params were set but
            # resize_size was None.
            if self.resize_size is not None:
                image = F.resize(image, self.resize_size)
            image = image.clamp(0.0, 1.0)
            new_observation[key] = image.to(device)
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = new_observation
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "crop_params_dict": self.crop_params_dict,
            "resize_size": self.resize_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessor:
    """Track episode time and enforce a maximum step count."""

    # Maximum number of steps before the episode is truncated.
    max_episode_steps: int
    # Steps taken since the last reset().
    current_step: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        truncated = transition.get(TransitionKey.TRUNCATED)
        # Without a truncation flag in the transition there is nothing to track.
        if truncated is None:
            return transition
        self.current_step += 1
        hit_limit = self.current_step >= self.max_episode_steps
        updated = transition.copy()
        updated[TransitionKey.TRUNCATED] = True if hit_limit else truncated
        return updated

    def get_config(self) -> dict[str, Any]:
        return {"max_episode_steps": self.max_episode_steps}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        # Restart the step counter at episode boundaries.
        self.current_step = 0

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessor:
    """Add a discrete penalty to complementary data for wasteful gripper use."""

    # Penalty value applied when the gripper command fights its current state.
    penalty: float = -0.01
    # Raw gripper position corresponding to fully open; used for normalization.
    max_gripper_pos: float = 30.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Calculate gripper penalty and add it to complementary data."""
        action = transition.get(TransitionKey.ACTION)
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
        if complementary_data is None or action is None:
            return transition
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        # Fix: guard before indexing. Previously the code did
        # `.get("raw_joint_positions", None)[-1]`, raising TypeError when the
        # key was missing, so the following None check never fired.
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        # The gripper command is assumed to be the last action element.
        gripper_action = action[-1].item()
        # Normalize the state to [0, 1]; shift the action by -1 (command is
        # presumably in [0, 2] — see the IK processor's gripper handling).
        gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
        gripper_action_normalized = gripper_action - 1.0
        # Penalize opening an almost-closed gripper or closing an open one.
        gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
            gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
        )
        new_complementary_data = dict(complementary_data)
        new_complementary_data["discrete_penalty"] = self.penalty * int(gripper_penalty_bool)
        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "penalty": self.penalty,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Reset the processor state (stateless; nothing to do)."""
        # Removed the dead `self.last_gripper_state = None`: no code reads it.
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessor:
    """Handle action intervention based on signals in the transition.

    This processor checks for intervention signals in the transition's
    complementary data and overrides agent actions when intervention is active.
    """

    # When True, a 4th gripper element is appended to the teleop action.
    use_gripper: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        # Intervention signals live in complementary data.
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        teleop_action = complementary_data.get("teleop_action", {})
        is_intervention = complementary_data.get("is_intervention", False)
        terminate_episode = complementary_data.get("terminate_episode", False)
        success = complementary_data.get("success", False)
        rerecord_episode = complementary_data.get("rerecord_episode", False)

        new_transition = transition.copy()
        # Override the agent action while a human intervention is active.
        if is_intervention and teleop_action:
            action_list = [
                teleop_action.get("delta_x", 0.0),
                teleop_action.get("delta_y", 0.0),
                teleop_action.get("delta_z", 0.0),
            ]
            if self.use_gripper:
                action_list.append(teleop_action.get("gripper", 1.0))
            new_transition[TransitionKey.ACTION] = torch.tensor(
                action_list, dtype=action.dtype, device=action.device
            )

        # Episode termination and sparse success reward.
        new_transition[TransitionKey.DONE] = bool(terminate_episode)
        new_transition[TransitionKey.REWARD] = float(success)

        # Fix: copy the nested dicts before mutating. transition.copy() is
        # shallow, so writing into the inherited info/complementary_data dicts
        # also mutated the caller's transition in place.
        info = dict(new_transition.get(TransitionKey.INFO, {}))
        info["is_intervention"] = is_intervention
        info["rerecord_episode"] = rerecord_episode
        info["next.success"] = success if terminate_episode else info.get("next.success", False)
        new_transition[TransitionKey.INFO] = info

        new_complementary_data = dict(complementary_data)
        # Record the action actually applied (teleop or agent) for downstream use.
        new_complementary_data["teleop_action"] = new_transition[TransitionKey.ACTION]
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {"use_gripper": self.use_gripper}

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
@dataclass
@ProcessorStepRegistry.register("inverse_kinematics_processor")
class InverseKinematicsProcessor:
    """Convert end-effector space actions to joint space using inverse kinematics.

    This processor transforms delta commands in end-effector space
    (delta_x, delta_y, delta_z) to joint space commands using forward and
    inverse kinematics. It maintains the current end-effector pose and joint
    positions to compute the transformations.
    """

    urdf_path: str
    target_frame_name: str = "gripper_link"
    # Per-axis scaling applied to the raw deltas.
    end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
    # Optional {"min": [...], "max": [...]} position bounds.
    end_effector_bounds: dict[str, list[float]] | None = None
    max_gripper_pos: float = 30.0
    # Environment reference kept for interface compatibility; not read here.
    env: gym.Env = None

    # State tracking (excluded from init/repr; reset() clears it).
    current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
    kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """Initialize the kinematics module after dataclass initialization."""
        if self.urdf_path:
            self.kinematics = RobotKinematics(
                urdf_path=self.urdf_path,
                target_frame_name=self.target_frame_name,
            )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        action = transition.get(TransitionKey.ACTION)
        if action is None:
            return transition
        action_np = action.detach().cpu().numpy().squeeze()
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        raw_joint_positions = complementary_data.get("raw_joint_positions")
        # Fix: bail out before indexing. Previously raw_joint_positions[-1]
        # raised TypeError when the key was missing, which made the later
        # `if self.current_joint_pos is None: return` unreachable dead code.
        if raw_joint_positions is None:
            return transition
        current_gripper_pos = raw_joint_positions[-1]
        if self.current_joint_pos is None:
            self.current_joint_pos = raw_joint_positions
        # Seed the end-effector pose from forward kinematics on first use.
        if self.current_ee_pos is None:
            self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
        # Scale deltas by per-axis step sizes.
        delta_ee = np.array(
            [
                action_np[0] * self.end_effector_step_sizes["x"],
                action_np[1] * self.end_effector_step_sizes["y"],
                action_np[2] * self.end_effector_step_sizes["z"],
            ],
            dtype=np.float32,
        )
        # Target pose: keep the current orientation, translate by the delta.
        desired_ee_pos = np.eye(4)
        desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3]
        desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
        if self.end_effector_bounds is not None:
            desired_ee_pos[:3, 3] = np.clip(
                desired_ee_pos[:3, 3],
                self.end_effector_bounds["min"],
                self.end_effector_bounds["max"],
            )
        # Solve IK for joint targets, then update the tracked state.
        target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)
        self.current_ee_pos = desired_ee_pos.copy()
        self.current_joint_pos = target_joint_values.copy()
        gripper_action = current_gripper_pos
        if len(action_np) > 3:
            # Map the gripper command from [0, 2] to a discrete +/- step and clip.
            gripper_command = action_np[3]
            gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
            gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)
        # The last joint slot carries the gripper command.
        target_joint_values[-1] = gripper_action
        converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)
        new_transition = transition.copy()
        new_transition[TransitionKey.ACTION] = converted_action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "urdf_path": self.urdf_path,
            "target_frame_name": self.target_frame_name,
            "end_effector_step_sizes": self.end_effector_step_sizes,
            "end_effector_bounds": self.end_effector_bounds,
            "max_gripper_pos": self.max_gripper_pos,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        pass

    def reset(self) -> None:
        """Reset the processor state."""
        self.current_ee_pos = None
        self.current_joint_pos = None

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
"""
Factory function to create a robot environment.
@@ -821,7 +357,6 @@ def make_processors(env, cfg):
end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
end_effector_bounds=cfg.processor.end_effector_bounds,
max_gripper_pos=cfg.processor.max_gripper_pos,
env=env,
),
]
action_processor = RobotProcessor(steps=action_pipeline_steps)