Compare commits

..

5 Commits

Author SHA1 Message Date
Adil Zouitine f247aa0701 refactor(tests): update processor test assertions to reflect new preprocessor and postprocessor names (#1869)
- Changed assertions in multiple processor test files to verify the updated names from "robot_preprocessor" and "robot_postprocessor" to "policy_preprocessor" and "policy_postprocessor" for consistency with recent refactoring.
2025-09-04 17:34:06 +02:00
Adil Zouitine 1ac6a6d3fe refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868)
- Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context.
- Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase.
2025-09-04 17:01:53 +02:00
Steven Palma e698c709d8 fix(deps): use in-house rotation utils over scipy throughout the codebase 2025-09-04 16:44:18 +02:00
Adil Zouitine a988da4789 feat(teleoperation): introduce HasTeleopEvents protocol and enhance teleop event handling (#1866)
- Added the HasTeleopEvents protocol to define a standard for teleoperators that provide control events.
- Implemented a runtime check to ensure teleoperators implement the get_teleop_events() method.
- Updated AddTeleopEventsAsInfoStep to utilize the new protocol, enhancing compatibility with custom teleoperators.
- Improved documentation for clarity on teleoperation event extraction and compatibility with built-in teleoperators.
2025-09-04 16:28:49 +02:00
Adil Zouitine 99963b6968 refactor(dependencies): remove scipy dependency and introduce custom rotation utilities (#1863)
- Removed the scipy dependency from the project to streamline requirements.
- Added a new `rotation.py` module containing a custom `Rotation` class that replicates essential functionalities of `scipy.spatial.transform.Rotation`, allowing for rotation vector, matrix, and quaternion conversions without external dependencies.
- Updated the `robot_kinematic_processor.py` to utilize the new custom rotation utilities.
2025-09-04 16:26:28 +02:00
24 changed files with 283 additions and 53 deletions
-1
View File
@@ -73,7 +73,6 @@ dependencies = [
"pynput>=1.7.7",
"pyserial>=3.5",
"wandb>=0.20.0",
"scipy>=1.15.2",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
+2 -2
View File
@@ -45,8 +45,8 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
if "LEROBOT_HOME" in os.environ:
raise ValueError(
+3 -3
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -59,12 +59,12 @@ def make_act_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -16,7 +16,7 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -59,12 +59,12 @@ def make_diffusion_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+7 -2
View File
@@ -24,6 +24,7 @@ from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
@@ -148,14 +149,18 @@ def make_pre_post_processors(
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=preprocessor_kwargs.get("to_transition"),
to_output=preprocessor_kwargs.get("to_output"),
),
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=postprocessor_kwargs.get("to_transition"),
to_output=postprocessor_kwargs.get("to_output"),
+3 -3
View File
@@ -18,7 +18,7 @@
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -107,12 +107,12 @@ def make_pi0_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -16,7 +16,7 @@
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -59,12 +59,12 @@ def make_pi0fast_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+3 -3
View File
@@ -17,7 +17,7 @@
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -60,12 +60,12 @@ def make_sac_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -17,7 +17,7 @@
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -70,12 +70,12 @@ def make_smolvla_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -16,7 +16,7 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -59,12 +59,12 @@ def make_tdmpc_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
@@ -17,7 +17,7 @@
# limitations under the License.
import torch
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -60,12 +60,12 @@ def make_vqbet_pre_post_processors(
return (
PolicyProcessorPipeline(
steps=input_steps,
name=PREPROCESSOR_DEFAULT_NAME,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
**preprocessor_kwargs,
),
PolicyProcessorPipeline(
steps=output_steps,
name=POSTPROCESSOR_DEFAULT_NAME,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
**postprocessor_kwargs,
),
)
+1 -1
View File
@@ -23,9 +23,9 @@ from typing import Any
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from lerobot.utils.rotation import Rotation
from .core import EnvTransition, TransitionKey
+54 -4
View File
@@ -1,7 +1,7 @@
import math
import time
from dataclasses import dataclass
from typing import Any
from typing import Any, Protocol, TypeVar, runtime_checkable
import numpy as np
import torch
@@ -27,6 +27,40 @@ DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"
@runtime_checkable
class HasTeleopEvents(Protocol):
"""Minimal protocol for objects that provide teleoperation events.
This protocol only defines the additional get_teleop_events() method,
avoiding duplication of the entire Teleoperator interface.
"""
def get_teleop_events(self) -> dict[str, Any]:
"""Get extra control events from the teleoperator.
Returns:
Dictionary containing control events such as:
- is_intervention: bool - Whether human is currently intervening
- terminate_episode: bool - Whether to terminate the current episode
- success: bool - Whether the episode was successful
- rerecord_episode: bool - Whether to rerecord the episode
"""
...
# Type variable constrained to Teleoperator subclasses that also implement events
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
def _check_teleop_with_events(teleop: Teleoperator) -> None:
"""Runtime check that a teleoperator implements get_teleop_events."""
if not isinstance(teleop, HasTeleopEvents):
raise TypeError(
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
)
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
@@ -46,13 +80,29 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
"""Add teleoperator control events to transition info."""
"""Add teleoperator control events to transition info.
teleop_device: Teleoperator
This processor step extracts control events from teleoperators that support
event-based interaction (intervention detection, episode termination, etc.).
Works with any teleoperator that inherits from Teleoperator and implements the
get_teleop_events() method, including custom user-defined teleoperators.
Built-in compatible teleoperators:
- GamepadTeleop: Uses gamepad buttons for control events
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
"""
teleop_device: TeleopWithEvents
def __post_init__(self):
"""Validate that the teleoperator supports events."""
_check_teleop_with_events(self.teleop_device)
def info(self, info: dict) -> dict:
new_info = dict(info)
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
teleop_events = self.teleop_device.get_teleop_events()
new_info.update(teleop_events)
return new_info
@@ -17,7 +17,6 @@
from dataclasses import dataclass, field
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
@@ -32,6 +31,7 @@ from lerobot.processor import (
TransitionKey,
)
from lerobot.robots.robot import Robot
from lerobot.utils.rotation import Rotation
@ProcessorStepRegistry.register("ee_reference_and_delta")
+5 -3
View File
@@ -26,7 +26,7 @@ from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
@@ -153,9 +153,11 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
preprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
)
postprocessor.from_pretrained(
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
@@ -24,12 +24,12 @@ import time
import hebi
import numpy as np
from scipy.spatial.transform import Rotation
from teleop import Teleop
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.utils.rotation import Rotation
logger = logging.getLogger(__name__)
+174
View File
@@ -0,0 +1,174 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom rotation utilities to replace scipy.spatial.transform.Rotation."""
import numpy as np
class Rotation:
"""
Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality.
Supports conversions between rotation vectors, rotation matrices, and quaternions.
"""
def __init__(self, quat: np.ndarray) -> None:
"""Initialize rotation from quaternion [x, y, z, w]."""
self._quat = np.asarray(quat, dtype=float)
# Normalize quaternion
norm = np.linalg.norm(self._quat)
if norm > 0:
self._quat = self._quat / norm
@classmethod
def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation":
"""
Create rotation from rotation vector using Rodrigues' formula.
Args:
rotvec: Rotation vector [x, y, z] where magnitude is angle in radians
Returns:
Rotation instance
"""
rotvec = np.asarray(rotvec, dtype=float)
angle = np.linalg.norm(rotvec)
if angle < 1e-8:
# For very small angles, use identity quaternion
quat = np.array([0.0, 0.0, 0.0, 1.0])
else:
axis = rotvec / angle
half_angle = angle / 2.0
sin_half = np.sin(half_angle)
cos_half = np.cos(half_angle)
# Quaternion [x, y, z, w]
quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half])
return cls(quat)
@classmethod
def from_matrix(cls, matrix: np.ndarray) -> "Rotation":
"""
Create rotation from 3x3 rotation matrix.
Args:
matrix: 3x3 rotation matrix
Returns:
Rotation instance
"""
matrix = np.asarray(matrix, dtype=float)
# Shepherd's method for converting rotation matrix to quaternion
trace = np.trace(matrix)
if trace > 0:
s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw
qw = 0.25 * s
qx = (matrix[2, 1] - matrix[1, 2]) / s
qy = (matrix[0, 2] - matrix[2, 0]) / s
qz = (matrix[1, 0] - matrix[0, 1]) / s
elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx
qw = (matrix[2, 1] - matrix[1, 2]) / s
qx = 0.25 * s
qy = (matrix[0, 1] + matrix[1, 0]) / s
qz = (matrix[0, 2] + matrix[2, 0]) / s
elif matrix[1, 1] > matrix[2, 2]:
s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy
qw = (matrix[0, 2] - matrix[2, 0]) / s
qx = (matrix[0, 1] + matrix[1, 0]) / s
qy = 0.25 * s
qz = (matrix[1, 2] + matrix[2, 1]) / s
else:
s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz
qw = (matrix[1, 0] - matrix[0, 1]) / s
qx = (matrix[0, 2] + matrix[2, 0]) / s
qy = (matrix[1, 2] + matrix[2, 1]) / s
qz = 0.25 * s
quat = np.array([qx, qy, qz, qw])
return cls(quat)
@classmethod
def from_quat(cls, quat: np.ndarray) -> "Rotation":
"""
Create rotation from quaternion.
Args:
quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring)
This implementation expects [x, y, z, w] format
Returns:
Rotation instance
"""
return cls(quat)
def as_matrix(self) -> np.ndarray:
"""
Convert rotation to 3x3 rotation matrix.
Returns:
3x3 rotation matrix
"""
qx, qy, qz, qw = self._quat
# Compute rotation matrix from quaternion
return np.array(
[
[1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],
[2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)],
[2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)],
],
dtype=float,
)
def as_rotvec(self) -> np.ndarray:
"""
Convert rotation to rotation vector.
Returns:
Rotation vector [x, y, z] where magnitude is angle in radians
"""
qx, qy, qz, qw = self._quat
# Ensure qw is positive for unique representation
if qw < 0:
qx, qy, qz, qw = -qx, -qy, -qz, -qw
# Compute angle and axis
angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0))
sin_half_angle = np.sqrt(1.0 - qw * qw)
if sin_half_angle < 1e-8:
# For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz]
return 2.0 * np.array([qx, qy, qz])
# Extract axis and scale by angle
axis = np.array([qx, qy, qz]) / sin_half_angle
return angle * axis
def as_quat(self) -> np.ndarray:
"""
Get quaternion representation.
Returns:
Quaternion [x, y, z, w]
"""
return self._quat.copy()
+2 -2
View File
@@ -81,8 +81,8 @@ def test_make_act_processor_basic():
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
+2 -2
View File
@@ -84,8 +84,8 @@ def test_make_diffusion_processor_basic():
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
+2 -2
View File
@@ -110,8 +110,8 @@ def test_make_pi0_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 6
+2 -2
View File
@@ -86,8 +86,8 @@ def test_make_sac_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
+2 -2
View File
@@ -117,8 +117,8 @@ def test_make_smolvla_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 6
+2 -2
View File
@@ -89,8 +89,8 @@ def test_make_tdmpc_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
+2 -2
View File
@@ -89,8 +89,8 @@ def test_make_vqbet_processor_basic():
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
assert postprocessor.name == "robot_postprocessor"
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 4