mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
change folder structure to reduce the size of gym_manip
This commit is contained in:
@@ -15,6 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .device_processor import DeviceProcessor
|
from .device_processor import DeviceProcessor
|
||||||
|
from .hil_processor import (
|
||||||
|
GripperPenaltyProcessor,
|
||||||
|
ImageCropResizeProcessor,
|
||||||
|
InterventionActionProcessor,
|
||||||
|
TimeLimitProcessor,
|
||||||
|
)
|
||||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||||
from .observation_processor import (
|
from .observation_processor import (
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
@@ -36,15 +42,26 @@ from .pipeline import (
|
|||||||
TruncatedProcessor,
|
TruncatedProcessor,
|
||||||
)
|
)
|
||||||
from .rename_processor import RenameProcessor
|
from .rename_processor import RenameProcessor
|
||||||
|
from .robot_processor import (
|
||||||
|
InverseKinematicsProcessor,
|
||||||
|
JointVelocityProcessor,
|
||||||
|
MotorCurrentProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ActionProcessor",
|
"ActionProcessor",
|
||||||
"DeviceProcessor",
|
"DeviceProcessor",
|
||||||
"DoneProcessor",
|
"DoneProcessor",
|
||||||
"EnvTransition",
|
"EnvTransition",
|
||||||
|
"GripperPenaltyProcessor",
|
||||||
"IdentityProcessor",
|
"IdentityProcessor",
|
||||||
|
"ImageCropResizeProcessor",
|
||||||
"ImageProcessor",
|
"ImageProcessor",
|
||||||
"InfoProcessor",
|
"InfoProcessor",
|
||||||
|
"InterventionActionProcessor",
|
||||||
|
"InverseKinematicsProcessor",
|
||||||
|
"JointVelocityProcessor",
|
||||||
|
"MotorCurrentProcessor",
|
||||||
"NormalizerProcessor",
|
"NormalizerProcessor",
|
||||||
"UnnormalizerProcessor",
|
"UnnormalizerProcessor",
|
||||||
"ObservationProcessor",
|
"ObservationProcessor",
|
||||||
@@ -54,6 +71,7 @@ __all__ = [
|
|||||||
"RewardProcessor",
|
"RewardProcessor",
|
||||||
"RobotProcessor",
|
"RobotProcessor",
|
||||||
"StateProcessor",
|
"StateProcessor",
|
||||||
|
"TimeLimitProcessor",
|
||||||
"TransitionKey",
|
"TransitionKey",
|
||||||
"TruncatedProcessor",
|
"TruncatedProcessor",
|
||||||
"VanillaObservationProcessor",
|
"VanillaObservationProcessor",
|
||||||
|
|||||||
@@ -0,0 +1,243 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms.functional as F # noqa: N812
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||||
|
class ImageCropResizeProcessor:
|
||||||
|
"""Crop and resize image observations."""
|
||||||
|
|
||||||
|
crop_params_dict: dict[str, tuple[int, int, int, int]]
|
||||||
|
resize_size: tuple[int, int] = (128, 128)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if observation is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
if self.resize_size is None and not self.crop_params_dict:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
new_observation = dict(observation)
|
||||||
|
|
||||||
|
# Process all image keys in the observation
|
||||||
|
for key in observation:
|
||||||
|
if "image" not in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image = observation[key]
|
||||||
|
device = image.device
|
||||||
|
if device.type == "mps":
|
||||||
|
image = image.cpu()
|
||||||
|
# Crop if crop params are provided for this key
|
||||||
|
if key in self.crop_params_dict:
|
||||||
|
crop_params = self.crop_params_dict[key]
|
||||||
|
image = F.crop(image, *crop_params)
|
||||||
|
# Always resize
|
||||||
|
image = F.resize(image, self.resize_size)
|
||||||
|
image = image.clamp(0.0, 1.0)
|
||||||
|
new_observation[key] = image.to(device)
|
||||||
|
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"crop_params_dict": self.crop_params_dict,
|
||||||
|
"resize_size": self.resize_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("time_limit_processor")
|
||||||
|
class TimeLimitProcessor:
|
||||||
|
"""Track episode time and enforce time limits."""
|
||||||
|
|
||||||
|
max_episode_steps: int
|
||||||
|
current_step: int = 0
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||||
|
if truncated is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
self.current_step += 1
|
||||||
|
if self.current_step >= self.max_episode_steps:
|
||||||
|
truncated = True
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.TRUNCATED] = truncated
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"max_episode_steps": self.max_episode_steps,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.current_step = 0
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||||
|
class GripperPenaltyProcessor:
|
||||||
|
penalty: float = -0.01
|
||||||
|
max_gripper_pos: float = 30.0
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
"""Calculate gripper penalty and add to complementary data."""
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||||
|
|
||||||
|
if complementary_data is None or action is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
current_gripper_pos = complementary_data.get("raw_joint_positions", None)[-1]
|
||||||
|
if current_gripper_pos is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
gripper_action = action[-1].item()
|
||||||
|
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||||
|
|
||||||
|
# Normalize gripper state and action
|
||||||
|
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||||
|
gripper_action_normalized = gripper_action - 1.0
|
||||||
|
|
||||||
|
# Calculate penalty boolean as in original
|
||||||
|
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||||
|
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||||
|
|
||||||
|
# Add penalty information to complementary data
|
||||||
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
|
||||||
|
# Create new complementary data with penalty info
|
||||||
|
new_complementary_data = dict(complementary_data)
|
||||||
|
new_complementary_data["discrete_penalty"] = gripper_penalty
|
||||||
|
|
||||||
|
# Create new transition with updated complementary data
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"penalty": self.penalty,
|
||||||
|
"max_gripper_pos": self.max_gripper_pos,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the processor state."""
|
||||||
|
self.last_gripper_state = None
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||||
|
class InterventionActionProcessor:
|
||||||
|
"""Handle action intervention based on signals in the transition.
|
||||||
|
|
||||||
|
This processor checks for intervention signals in the transition's complementary data
|
||||||
|
and overrides agent actions when intervention is active.
|
||||||
|
"""
|
||||||
|
|
||||||
|
use_gripper: bool = False
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
# Get intervention signals from complementary data
|
||||||
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
teleop_action = complementary_data.get("teleop_action", {})
|
||||||
|
is_intervention = complementary_data.get("is_intervention", False)
|
||||||
|
terminate_episode = complementary_data.get("terminate_episode", False)
|
||||||
|
success = complementary_data.get("success", False)
|
||||||
|
rerecord_episode = complementary_data.get("rerecord_episode", False)
|
||||||
|
|
||||||
|
new_transition = transition.copy()
|
||||||
|
|
||||||
|
# Override action if intervention is active
|
||||||
|
if is_intervention and teleop_action:
|
||||||
|
# Convert teleop_action dict to tensor format
|
||||||
|
action_list = [
|
||||||
|
teleop_action.get("delta_x", 0.0),
|
||||||
|
teleop_action.get("delta_y", 0.0),
|
||||||
|
teleop_action.get("delta_z", 0.0),
|
||||||
|
]
|
||||||
|
if self.use_gripper:
|
||||||
|
action_list.append(teleop_action.get("gripper", 1.0))
|
||||||
|
|
||||||
|
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
|
||||||
|
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
||||||
|
|
||||||
|
# Handle episode termination
|
||||||
|
new_transition[TransitionKey.DONE] = bool(terminate_episode)
|
||||||
|
new_transition[TransitionKey.REWARD] = float(success)
|
||||||
|
|
||||||
|
# Update info with intervention metadata
|
||||||
|
info = new_transition.get(TransitionKey.INFO, {})
|
||||||
|
info["is_intervention"] = is_intervention
|
||||||
|
info["rerecord_episode"] = rerecord_episode
|
||||||
|
info["next.success"] = success if terminate_episode else info.get("next.success", False)
|
||||||
|
new_transition[TransitionKey.INFO] = info
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[
|
||||||
|
TransitionKey.ACTION
|
||||||
|
]
|
||||||
|
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"use_gripper": self.use_gripper,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
@@ -0,0 +1,245 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
|
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||||
|
class JointVelocityProcessor:
|
||||||
|
"""Add joint velocity information to observations.
|
||||||
|
|
||||||
|
Computes joint velocities by tracking changes in joint positions over time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
joint_velocity_limits: float = 100.0
|
||||||
|
dt: float = 1.0 / 10
|
||||||
|
|
||||||
|
last_joint_positions: torch.Tensor | None = None
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if observation is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
# Get current joint positions (assuming they're in observation.state)
|
||||||
|
current_positions = observation.get("observation.state")
|
||||||
|
if current_positions is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
# Initialize last joint positions if not already set
|
||||||
|
if self.last_joint_positions is None:
|
||||||
|
self.last_joint_positions = current_positions.clone()
|
||||||
|
|
||||||
|
# Compute velocities
|
||||||
|
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
||||||
|
self.last_joint_positions = current_positions.clone()
|
||||||
|
|
||||||
|
# Extend observation with velocities
|
||||||
|
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
|
||||||
|
|
||||||
|
# Create new observation dict
|
||||||
|
new_observation = dict(observation)
|
||||||
|
new_observation["observation.state"] = extended_state
|
||||||
|
|
||||||
|
# Return new transition
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"joint_velocity_limits": self.joint_velocity_limits,
|
||||||
|
"dt": self.dt,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.last_joint_positions = None
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("current_processor")
|
||||||
|
class MotorCurrentProcessor:
|
||||||
|
"""Add motor current information to observations."""
|
||||||
|
|
||||||
|
env: gym.Env = None
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
observation = transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if observation is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
# Get current values from complementary_data (where robot state would be stored)
|
||||||
|
present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current")
|
||||||
|
motor_currents = torch.tensor(
|
||||||
|
[present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors],
|
||||||
|
dtype=torch.float32,
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
current_state = observation.get("observation.state")
|
||||||
|
if current_state is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
extended_state = torch.cat([current_state, motor_currents], dim=-1)
|
||||||
|
|
||||||
|
# Create new observation dict
|
||||||
|
new_observation = dict(observation)
|
||||||
|
new_observation["observation.state"] = extended_state
|
||||||
|
|
||||||
|
# Return new transition
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register("inverse_kinematics_processor")
|
||||||
|
class InverseKinematicsProcessor:
|
||||||
|
"""Convert end-effector space actions to joint space using inverse kinematics.
|
||||||
|
|
||||||
|
This processor transforms delta commands in end-effector space (delta_x, delta_y, delta_z)
|
||||||
|
to joint space commands using forward and inverse kinematics. It maintains the current
|
||||||
|
end-effector pose and joint positions to compute the transformations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
urdf_path: str
|
||||||
|
target_frame_name: str = "gripper_link"
|
||||||
|
end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
|
||||||
|
end_effector_bounds: dict[str, list[float]] | None = None
|
||||||
|
max_gripper_pos: float = 30.0
|
||||||
|
|
||||||
|
# State tracking
|
||||||
|
current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
|
current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
|
kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Initialize the kinematics module after dataclass initialization."""
|
||||||
|
if self.urdf_path:
|
||||||
|
self.kinematics = RobotKinematics(
|
||||||
|
urdf_path=self.urdf_path,
|
||||||
|
target_frame_name=self.target_frame_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is None:
|
||||||
|
return transition
|
||||||
|
|
||||||
|
action_np = action.detach().cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
|
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||||
|
current_gripper_pos = raw_joint_positions[-1]
|
||||||
|
if self.current_joint_pos is None:
|
||||||
|
self.current_joint_pos = raw_joint_positions
|
||||||
|
|
||||||
|
# Initialize end-effector position if not available
|
||||||
|
if self.current_joint_pos is None:
|
||||||
|
return transition # Cannot proceed without joint positions
|
||||||
|
|
||||||
|
# Calculate current end-effector position using forward kinematics
|
||||||
|
if self.current_ee_pos is None:
|
||||||
|
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
|
||||||
|
|
||||||
|
# Scale deltas by step sizes
|
||||||
|
delta_ee = np.array(
|
||||||
|
[
|
||||||
|
action_np[0] * self.end_effector_step_sizes["x"],
|
||||||
|
action_np[1] * self.end_effector_step_sizes["y"],
|
||||||
|
action_np[2] * self.end_effector_step_sizes["z"],
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set desired end-effector position by adding delta
|
||||||
|
desired_ee_pos = np.eye(4)
|
||||||
|
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
|
||||||
|
|
||||||
|
# Add delta to position and clip to bounds
|
||||||
|
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
|
||||||
|
if self.end_effector_bounds is not None:
|
||||||
|
desired_ee_pos[:3, 3] = np.clip(
|
||||||
|
desired_ee_pos[:3, 3],
|
||||||
|
self.end_effector_bounds["min"],
|
||||||
|
self.end_effector_bounds["max"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute inverse kinematics to get joint positions
|
||||||
|
target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)
|
||||||
|
|
||||||
|
# Update current state
|
||||||
|
self.current_ee_pos = desired_ee_pos.copy()
|
||||||
|
self.current_joint_pos = target_joint_values.copy()
|
||||||
|
|
||||||
|
# Create new action with joint space commands
|
||||||
|
gripper_action = current_gripper_pos
|
||||||
|
if len(action_np) > 3:
|
||||||
|
# Handle gripper command separately
|
||||||
|
gripper_command = action_np[3]
|
||||||
|
|
||||||
|
# Process gripper command (convert from [0,2] to delta) and discretize
|
||||||
|
gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
|
||||||
|
gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)
|
||||||
|
|
||||||
|
# Combine joint positions and gripper
|
||||||
|
target_joint_values[-1] = gripper_action
|
||||||
|
|
||||||
|
converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)
|
||||||
|
|
||||||
|
new_transition = transition.copy()
|
||||||
|
new_transition[TransitionKey.ACTION] = converted_action
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"urdf_path": self.urdf_path,
|
||||||
|
"target_frame_name": self.target_frame_name,
|
||||||
|
"end_effector_step_sizes": self.end_effector_step_sizes,
|
||||||
|
"end_effector_bounds": self.end_effector_bounds,
|
||||||
|
"max_gripper_pos": self.max_gripper_pos,
|
||||||
|
}
|
||||||
|
|
||||||
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the processor state."""
|
||||||
|
self.current_ee_pos = None
|
||||||
|
self.current_joint_pos = None
|
||||||
|
|
||||||
|
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
|
return features
|
||||||
@@ -16,27 +16,34 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F # noqa: N812
|
|
||||||
|
|
||||||
from lerobot.cameras import opencv # noqa: F401
|
from lerobot.cameras import opencv # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.types import PolicyFeature
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.envs.configs import EnvConfig
|
from lerobot.envs.configs import EnvConfig
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
DeviceProcessor,
|
DeviceProcessor,
|
||||||
ImageProcessor,
|
ImageProcessor,
|
||||||
RobotProcessor,
|
RobotProcessor,
|
||||||
StateProcessor,
|
StateProcessor,
|
||||||
)
|
)
|
||||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
from lerobot.processor.hil_processor import (
|
||||||
|
GripperPenaltyProcessor,
|
||||||
|
ImageCropResizeProcessor,
|
||||||
|
InterventionActionProcessor,
|
||||||
|
TimeLimitProcessor,
|
||||||
|
)
|
||||||
|
from lerobot.processor.pipeline import TransitionKey
|
||||||
|
from lerobot.processor.robot_processor import (
|
||||||
|
InverseKinematicsProcessor,
|
||||||
|
JointVelocityProcessor,
|
||||||
|
MotorCurrentProcessor,
|
||||||
|
)
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
make_robot_from_config,
|
make_robot_from_config,
|
||||||
@@ -279,477 +286,6 @@ class RobotEnv(gym.Env):
|
|||||||
self.robot.disconnect()
|
self.robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
|
||||||
class JointVelocityProcessor:
|
|
||||||
"""Add joint velocity information to observations.
|
|
||||||
|
|
||||||
Computes joint velocities by tracking changes in joint positions over time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
joint_velocity_limits: float = 100.0
|
|
||||||
dt: float = 1.0 / 10
|
|
||||||
|
|
||||||
last_joint_positions: torch.Tensor | None = None
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
|
||||||
if observation is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# Get current joint positions (assuming they're in observation.state)
|
|
||||||
current_positions = observation.get("observation.state")
|
|
||||||
if current_positions is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# Initialize last joint positions if not already set
|
|
||||||
if self.last_joint_positions is None:
|
|
||||||
self.last_joint_positions = current_positions.clone()
|
|
||||||
|
|
||||||
# Compute velocities
|
|
||||||
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
|
||||||
self.last_joint_positions = current_positions.clone()
|
|
||||||
|
|
||||||
# Extend observation with velocities
|
|
||||||
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
|
|
||||||
|
|
||||||
# Create new observation dict
|
|
||||||
new_observation = dict(observation)
|
|
||||||
new_observation["observation.state"] = extended_state
|
|
||||||
|
|
||||||
# Return new transition
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"joint_velocity_limits": self.joint_velocity_limits,
|
|
||||||
"fps": self.fps,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self.last_joint_positions = None
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("current_processor")
|
|
||||||
class MotorCurrentProcessor:
|
|
||||||
"""Add motor current information to observations."""
|
|
||||||
|
|
||||||
env: gym.Env = None
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
|
||||||
if observation is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# Get current values from complementary_data (where robot state would be stored)
|
|
||||||
present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current")
|
|
||||||
motor_currents = torch.tensor(
|
|
||||||
[present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors],
|
|
||||||
dtype=torch.float32,
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
current_state = observation.get("observation.state")
|
|
||||||
if current_state is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
extended_state = torch.cat([current_state, motor_currents], dim=-1)
|
|
||||||
|
|
||||||
# Create new observation dict
|
|
||||||
new_observation = dict(observation)
|
|
||||||
new_observation["observation.state"] = extended_state
|
|
||||||
|
|
||||||
# Return new transition
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
|
||||||
class ImageCropResizeProcessor:
|
|
||||||
"""Crop and resize image observations."""
|
|
||||||
|
|
||||||
crop_params_dict: dict[str, tuple[int, int, int, int]]
|
|
||||||
resize_size: tuple[int, int] = (128, 128)
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
|
||||||
if observation is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
if self.resize_size is None and not self.crop_params_dict:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
new_observation = dict(observation)
|
|
||||||
|
|
||||||
# Process all image keys in the observation
|
|
||||||
for key in observation:
|
|
||||||
if "image" not in key:
|
|
||||||
continue
|
|
||||||
|
|
||||||
image = observation[key]
|
|
||||||
device = image.device
|
|
||||||
if device.type == "mps":
|
|
||||||
image = image.cpu()
|
|
||||||
# Crop if crop params are provided for this key
|
|
||||||
if key in self.crop_params_dict:
|
|
||||||
crop_params = self.crop_params_dict[key]
|
|
||||||
image = F.crop(image, *crop_params)
|
|
||||||
# Always resize
|
|
||||||
image = F.resize(image, self.resize_size)
|
|
||||||
image = image.clamp(0.0, 1.0)
|
|
||||||
new_observation[key] = image.to(device)
|
|
||||||
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"crop_params_dict": self.crop_params_dict,
|
|
||||||
"resize_size": self.resize_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("time_limit_processor")
|
|
||||||
class TimeLimitProcessor:
|
|
||||||
"""Track episode time and enforce time limits."""
|
|
||||||
|
|
||||||
max_episode_steps: int
|
|
||||||
current_step: int = 0
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
|
||||||
if truncated is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
self.current_step += 1
|
|
||||||
if self.current_step >= self.max_episode_steps:
|
|
||||||
truncated = True
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.TRUNCATED] = truncated
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"max_episode_steps": self.max_episode_steps,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
self.current_step = 0
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
|
||||||
class GripperPenaltyProcessor:
|
|
||||||
penalty: float = -0.01
|
|
||||||
max_gripper_pos: float = 30.0
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
"""Calculate gripper penalty and add to complementary data."""
|
|
||||||
action = transition.get(TransitionKey.ACTION)
|
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
|
||||||
|
|
||||||
if complementary_data is None or action is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None)[-1]
|
|
||||||
if current_gripper_pos is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
gripper_action = action[-1].item()
|
|
||||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
|
||||||
|
|
||||||
# Normalize gripper state and action
|
|
||||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
|
||||||
gripper_action_normalized = gripper_action - 1.0
|
|
||||||
|
|
||||||
# Calculate penalty boolean as in original
|
|
||||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
|
||||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
|
||||||
|
|
||||||
# Add penalty information to complementary data
|
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
|
||||||
|
|
||||||
# Create new complementary data with penalty info
|
|
||||||
new_complementary_data = dict(complementary_data)
|
|
||||||
new_complementary_data["discrete_penalty"] = gripper_penalty
|
|
||||||
|
|
||||||
# Create new transition with updated complementary data
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"penalty": self.penalty,
|
|
||||||
"max_gripper_pos": self.max_gripper_pos,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset the processor state."""
|
|
||||||
self.last_gripper_state = None
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
|
||||||
class InterventionActionProcessor:
|
|
||||||
"""Handle action intervention based on signals in the transition.
|
|
||||||
|
|
||||||
This processor checks for intervention signals in the transition's complementary data
|
|
||||||
and overrides agent actions when intervention is active.
|
|
||||||
"""
|
|
||||||
|
|
||||||
use_gripper: bool = False
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
action = transition.get(TransitionKey.ACTION)
|
|
||||||
if action is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# Get intervention signals from complementary data
|
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
|
||||||
teleop_action = complementary_data.get("teleop_action", {})
|
|
||||||
is_intervention = complementary_data.get("is_intervention", False)
|
|
||||||
terminate_episode = complementary_data.get("terminate_episode", False)
|
|
||||||
success = complementary_data.get("success", False)
|
|
||||||
rerecord_episode = complementary_data.get("rerecord_episode", False)
|
|
||||||
|
|
||||||
new_transition = transition.copy()
|
|
||||||
|
|
||||||
# Override action if intervention is active
|
|
||||||
if is_intervention and teleop_action:
|
|
||||||
# Convert teleop_action dict to tensor format
|
|
||||||
action_list = [
|
|
||||||
teleop_action.get("delta_x", 0.0),
|
|
||||||
teleop_action.get("delta_y", 0.0),
|
|
||||||
teleop_action.get("delta_z", 0.0),
|
|
||||||
]
|
|
||||||
if self.use_gripper:
|
|
||||||
action_list.append(teleop_action.get("gripper", 1.0))
|
|
||||||
|
|
||||||
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
|
|
||||||
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
|
||||||
|
|
||||||
# Handle episode termination
|
|
||||||
new_transition[TransitionKey.DONE] = bool(terminate_episode)
|
|
||||||
new_transition[TransitionKey.REWARD] = float(success)
|
|
||||||
|
|
||||||
# Update info with intervention metadata
|
|
||||||
info = new_transition.get(TransitionKey.INFO, {})
|
|
||||||
info["is_intervention"] = is_intervention
|
|
||||||
info["rerecord_episode"] = rerecord_episode
|
|
||||||
info["next.success"] = success if terminate_episode else info.get("next.success", False)
|
|
||||||
new_transition[TransitionKey.INFO] = info
|
|
||||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] = new_transition[
|
|
||||||
TransitionKey.ACTION
|
|
||||||
]
|
|
||||||
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"use_gripper": self.use_gripper,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register("inverse_kinematics_processor")
|
|
||||||
class InverseKinematicsProcessor:
|
|
||||||
"""Convert end-effector space actions to joint space using inverse kinematics.
|
|
||||||
|
|
||||||
This processor transforms delta commands in end-effector space (delta_x, delta_y, delta_z)
|
|
||||||
to joint space commands using forward and inverse kinematics. It maintains the current
|
|
||||||
end-effector pose and joint positions to compute the transformations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
urdf_path: str
|
|
||||||
target_frame_name: str = "gripper_link"
|
|
||||||
end_effector_step_sizes: dict[str, float] = field(default_factory=lambda: {"x": 1.0, "y": 1.0, "z": 1.0})
|
|
||||||
end_effector_bounds: dict[str, list[float]] | None = None
|
|
||||||
max_gripper_pos: float = 30.0
|
|
||||||
env: gym.Env = None # Environment reference to get current state
|
|
||||||
|
|
||||||
# State tracking
|
|
||||||
current_ee_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
|
||||||
current_joint_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
|
||||||
kinematics: RobotKinematics | None = field(default=None, init=False, repr=False)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Initialize the kinematics module after dataclass initialization."""
|
|
||||||
if self.urdf_path:
|
|
||||||
self.kinematics = RobotKinematics(
|
|
||||||
urdf_path=self.urdf_path,
|
|
||||||
target_frame_name=self.target_frame_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
|
||||||
action = transition.get(TransitionKey.ACTION)
|
|
||||||
if action is None:
|
|
||||||
return transition
|
|
||||||
|
|
||||||
action_np = action.detach().cpu().numpy().squeeze()
|
|
||||||
|
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
|
||||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
|
||||||
current_gripper_pos = raw_joint_positions[-1]
|
|
||||||
if self.current_joint_pos is None:
|
|
||||||
self.current_joint_pos = raw_joint_positions
|
|
||||||
|
|
||||||
# Initialize end-effector position if not available
|
|
||||||
if self.current_joint_pos is None:
|
|
||||||
return transition # Cannot proceed without joint positions
|
|
||||||
|
|
||||||
# Calculate current end-effector position using forward kinematics
|
|
||||||
if self.current_ee_pos is None:
|
|
||||||
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
|
|
||||||
|
|
||||||
# Scale deltas by step sizes
|
|
||||||
delta_ee = np.array(
|
|
||||||
[
|
|
||||||
action_np[0] * self.end_effector_step_sizes["x"],
|
|
||||||
action_np[1] * self.end_effector_step_sizes["y"],
|
|
||||||
action_np[2] * self.end_effector_step_sizes["z"],
|
|
||||||
],
|
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set desired end-effector position by adding delta
|
|
||||||
desired_ee_pos = np.eye(4)
|
|
||||||
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
|
|
||||||
|
|
||||||
# Add delta to position and clip to bounds
|
|
||||||
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + delta_ee
|
|
||||||
if self.end_effector_bounds is not None:
|
|
||||||
desired_ee_pos[:3, 3] = np.clip(
|
|
||||||
desired_ee_pos[:3, 3],
|
|
||||||
self.end_effector_bounds["min"],
|
|
||||||
self.end_effector_bounds["max"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute inverse kinematics to get joint positions
|
|
||||||
target_joint_values = self.kinematics.inverse_kinematics(self.current_joint_pos, desired_ee_pos)
|
|
||||||
|
|
||||||
# Update current state
|
|
||||||
self.current_ee_pos = desired_ee_pos.copy()
|
|
||||||
self.current_joint_pos = target_joint_values.copy()
|
|
||||||
|
|
||||||
# Create new action with joint space commands
|
|
||||||
gripper_action = current_gripper_pos
|
|
||||||
if len(action_np) > 3:
|
|
||||||
# Handle gripper command separately
|
|
||||||
gripper_command = action_np[3]
|
|
||||||
|
|
||||||
# Process gripper command (convert from [0,2] to delta) and discretize
|
|
||||||
gripper_delta = np.round(gripper_command - 1.0).astype(int) * self.max_gripper_pos
|
|
||||||
gripper_action = np.clip(current_gripper_pos + gripper_delta, 0, self.max_gripper_pos)
|
|
||||||
|
|
||||||
# Combine joint positions and gripper
|
|
||||||
target_joint_values[-1] = gripper_action
|
|
||||||
|
|
||||||
converted_action = torch.from_numpy(target_joint_values).to(action.device).to(action.dtype)
|
|
||||||
|
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.ACTION] = converted_action
|
|
||||||
return new_transition
|
|
||||||
|
|
||||||
def get_config(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"urdf_path": self.urdf_path,
|
|
||||||
"target_frame_name": self.target_frame_name,
|
|
||||||
"end_effector_step_sizes": self.end_effector_step_sizes,
|
|
||||||
"end_effector_bounds": self.end_effector_bounds,
|
|
||||||
"max_gripper_pos": self.max_gripper_pos,
|
|
||||||
}
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset the processor state."""
|
|
||||||
self.current_ee_pos = None
|
|
||||||
self.current_joint_pos = None
|
|
||||||
|
|
||||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
|
||||||
return features
|
|
||||||
|
|
||||||
|
|
||||||
def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
|
def make_robot_env(cfg: EnvConfig) -> tuple[gym.Env, Any]:
|
||||||
"""
|
"""
|
||||||
Factory function to create a robot environment.
|
Factory function to create a robot environment.
|
||||||
@@ -821,7 +357,6 @@ def make_processors(env, cfg):
|
|||||||
end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
|
end_effector_step_sizes=cfg.processor.end_effector_step_sizes,
|
||||||
end_effector_bounds=cfg.processor.end_effector_bounds,
|
end_effector_bounds=cfg.processor.end_effector_bounds,
|
||||||
max_gripper_pos=cfg.processor.max_gripper_pos,
|
max_gripper_pos=cfg.processor.max_gripper_pos,
|
||||||
env=env,
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
action_processor = RobotProcessor(steps=action_pipeline_steps)
|
action_processor = RobotProcessor(steps=action_pipeline_steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user