mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f247aa0701 | |||
| 1ac6a6d3fe | |||
| e698c709d8 | |||
| a988da4789 | |||
| 99963b6968 |
@@ -73,7 +73,6 @@ dependencies = [
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"scipy>=1.15.2",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
|
||||
@@ -45,8 +45,8 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
|
||||
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -59,12 +59,12 @@ def make_act_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -59,12 +59,12 @@ def make_diffusion_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
@@ -148,14 +149,18 @@ def make_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
|
||||
config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=preprocessor_kwargs.get("to_transition"),
|
||||
to_output=preprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
|
||||
config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=postprocessor_kwargs.get("to_transition"),
|
||||
to_output=postprocessor_kwargs.get("to_output"),
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -107,12 +107,12 @@ def make_pi0_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -59,12 +59,12 @@ def make_pi0fast_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -60,12 +60,12 @@ def make_sac_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -70,12 +70,12 @@ def make_smolvla_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -59,12 +59,12 @@ def make_tdmpc_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -60,12 +60,12 @@ def make_vqbet_pre_post_processors(
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -23,9 +23,9 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -27,6 +27,40 @@ DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||
TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasTeleopEvents(Protocol):
|
||||
"""Minimal protocol for objects that provide teleoperation events.
|
||||
|
||||
This protocol only defines the additional get_teleop_events() method,
|
||||
avoiding duplication of the entire Teleoperator interface.
|
||||
"""
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""Get extra control events from the teleoperator.
|
||||
|
||||
Returns:
|
||||
Dictionary containing control events such as:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
"""Runtime check that a teleoperator implements get_teleop_events."""
|
||||
if not isinstance(teleop, HasTeleopEvents):
|
||||
raise TypeError(
|
||||
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
|
||||
f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
@@ -46,13 +80,29 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
"""Add teleoperator control events to transition info."""
|
||||
"""Add teleoperator control events to transition info.
|
||||
|
||||
teleop_device: Teleoperator
|
||||
This processor step extracts control events from teleoperators that support
|
||||
event-based interaction (intervention detection, episode termination, etc.).
|
||||
|
||||
Works with any teleoperator that inherits from Teleoperator and implements the
|
||||
get_teleop_events() method, including custom user-defined teleoperators.
|
||||
|
||||
Built-in compatible teleoperators:
|
||||
- GamepadTeleop: Uses gamepad buttons for control events
|
||||
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
|
||||
"""
|
||||
|
||||
teleop_device: TeleopWithEvents
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that the teleoperator supports events."""
|
||||
_check_teleop_with_events(self.teleop_device)
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
new_info = dict(info)
|
||||
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
|
||||
|
||||
teleop_events = self.teleop_device.get_teleop_events()
|
||||
new_info.update(teleop_events)
|
||||
return new_info
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
@@ -32,6 +31,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||
|
||||
@@ -26,7 +26,7 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
@@ -153,9 +153,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
|
||||
preprocessor.from_pretrained(
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
postprocessor.from_pretrained(
|
||||
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
|
||||
@@ -24,12 +24,12 @@ import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom rotation utilities to replace scipy.spatial.transform.Rotation."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Rotation:
|
||||
"""
|
||||
Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality.
|
||||
|
||||
Supports conversions between rotation vectors, rotation matrices, and quaternions.
|
||||
"""
|
||||
|
||||
def __init__(self, quat: np.ndarray) -> None:
|
||||
"""Initialize rotation from quaternion [x, y, z, w]."""
|
||||
self._quat = np.asarray(quat, dtype=float)
|
||||
# Normalize quaternion
|
||||
norm = np.linalg.norm(self._quat)
|
||||
if norm > 0:
|
||||
self._quat = self._quat / norm
|
||||
|
||||
@classmethod
|
||||
def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from rotation vector using Rodrigues' formula.
|
||||
|
||||
Args:
|
||||
rotvec: Rotation vector [x, y, z] where magnitude is angle in radians
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
rotvec = np.asarray(rotvec, dtype=float)
|
||||
angle = np.linalg.norm(rotvec)
|
||||
|
||||
if angle < 1e-8:
|
||||
# For very small angles, use identity quaternion
|
||||
quat = np.array([0.0, 0.0, 0.0, 1.0])
|
||||
else:
|
||||
axis = rotvec / angle
|
||||
half_angle = angle / 2.0
|
||||
sin_half = np.sin(half_angle)
|
||||
cos_half = np.cos(half_angle)
|
||||
|
||||
# Quaternion [x, y, z, w]
|
||||
quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half])
|
||||
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def from_matrix(cls, matrix: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from 3x3 rotation matrix.
|
||||
|
||||
Args:
|
||||
matrix: 3x3 rotation matrix
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
matrix = np.asarray(matrix, dtype=float)
|
||||
|
||||
# Shepherd's method for converting rotation matrix to quaternion
|
||||
trace = np.trace(matrix)
|
||||
|
||||
if trace > 0:
|
||||
s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw
|
||||
qw = 0.25 * s
|
||||
qx = (matrix[2, 1] - matrix[1, 2]) / s
|
||||
qy = (matrix[0, 2] - matrix[2, 0]) / s
|
||||
qz = (matrix[1, 0] - matrix[0, 1]) / s
|
||||
elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
|
||||
s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx
|
||||
qw = (matrix[2, 1] - matrix[1, 2]) / s
|
||||
qx = 0.25 * s
|
||||
qy = (matrix[0, 1] + matrix[1, 0]) / s
|
||||
qz = (matrix[0, 2] + matrix[2, 0]) / s
|
||||
elif matrix[1, 1] > matrix[2, 2]:
|
||||
s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy
|
||||
qw = (matrix[0, 2] - matrix[2, 0]) / s
|
||||
qx = (matrix[0, 1] + matrix[1, 0]) / s
|
||||
qy = 0.25 * s
|
||||
qz = (matrix[1, 2] + matrix[2, 1]) / s
|
||||
else:
|
||||
s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz
|
||||
qw = (matrix[1, 0] - matrix[0, 1]) / s
|
||||
qx = (matrix[0, 2] + matrix[2, 0]) / s
|
||||
qy = (matrix[1, 2] + matrix[2, 1]) / s
|
||||
qz = 0.25 * s
|
||||
|
||||
quat = np.array([qx, qy, qz, qw])
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def from_quat(cls, quat: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from quaternion.
|
||||
|
||||
Args:
|
||||
quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring)
|
||||
This implementation expects [x, y, z, w] format
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
return cls(quat)
|
||||
|
||||
def as_matrix(self) -> np.ndarray:
|
||||
"""
|
||||
Convert rotation to 3x3 rotation matrix.
|
||||
|
||||
Returns:
|
||||
3x3 rotation matrix
|
||||
"""
|
||||
qx, qy, qz, qw = self._quat
|
||||
|
||||
# Compute rotation matrix from quaternion
|
||||
return np.array(
|
||||
[
|
||||
[1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],
|
||||
[2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)],
|
||||
[2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
def as_rotvec(self) -> np.ndarray:
|
||||
"""
|
||||
Convert rotation to rotation vector.
|
||||
|
||||
Returns:
|
||||
Rotation vector [x, y, z] where magnitude is angle in radians
|
||||
"""
|
||||
qx, qy, qz, qw = self._quat
|
||||
|
||||
# Ensure qw is positive for unique representation
|
||||
if qw < 0:
|
||||
qx, qy, qz, qw = -qx, -qy, -qz, -qw
|
||||
|
||||
# Compute angle and axis
|
||||
angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0))
|
||||
sin_half_angle = np.sqrt(1.0 - qw * qw)
|
||||
|
||||
if sin_half_angle < 1e-8:
|
||||
# For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz]
|
||||
return 2.0 * np.array([qx, qy, qz])
|
||||
|
||||
# Extract axis and scale by angle
|
||||
axis = np.array([qx, qy, qz]) / sin_half_angle
|
||||
return angle * axis
|
||||
|
||||
def as_quat(self) -> np.ndarray:
|
||||
"""
|
||||
Get quaternion representation.
|
||||
|
||||
Returns:
|
||||
Quaternion [x, y, z, w]
|
||||
"""
|
||||
return self._quat.copy()
|
||||
@@ -81,8 +81,8 @@ def test_make_act_processor_basic():
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
|
||||
@@ -84,8 +84,8 @@ def test_make_diffusion_processor_basic():
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
|
||||
@@ -110,8 +110,8 @@ def test_make_pi0_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
|
||||
@@ -86,8 +86,8 @@ def test_make_sac_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
|
||||
@@ -117,8 +117,8 @@ def test_make_smolvla_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
|
||||
@@ -89,8 +89,8 @@ def test_make_tdmpc_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
|
||||
@@ -89,8 +89,8 @@ def test_make_vqbet_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
|
||||
Reference in New Issue
Block a user