From 366bef915cdff80d0a1bcf4834130eb63a479cb5 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:26:49 +0000 Subject: [PATCH 01/43] add task ids to libero env cfg (#2842) --- src/lerobot/envs/configs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 112d3a73f..cd88b37bc 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -260,6 +260,7 @@ class HILSerlRobotEnvConfig(EnvConfig): @dataclass class LiberoEnv(EnvConfig): task: str = "libero_10" # can also choose libero_spatial, libero_object, etc. + task_ids: list[int] | None = None fps: int = 30 episode_length: int | None = None obs_type: str = "pixels_agent_pos" @@ -338,10 +339,10 @@ class LiberoEnv(EnvConfig): @property def gym_kwargs(self) -> dict: - return { - "obs_type": self.obs_type, - "render_mode": self.render_mode, - } + kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode} + if self.task_ids is not None: + kwargs["task_ids"] = self.task_ids + return kwargs @EnvConfig.register_subclass("metaworld") From 9cfb5ce5468d0f2df568f7ba3f902f13f89802a7 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 26 Jan 2026 17:53:25 +0100 Subject: [PATCH 02/43] feat(motors): add damiao motors & can bus (#2788) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + 
recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Pepijn --- docs/source/_toctree.yml | 2 + docs/source/damiao.mdx | 165 +++++ pyproject.toml | 3 + src/lerobot/motors/__init__.py | 6 +- src/lerobot/motors/calibration_gui.py | 2 +- src/lerobot/motors/damiao/__init__.py | 18 + src/lerobot/motors/damiao/damiao.py | 808 ++++++++++++++++++++++ src/lerobot/motors/damiao/tables.py | 209 ++++++ src/lerobot/motors/dynamixel/dynamixel.py | 11 +- src/lerobot/motors/feetech/feetech.py | 13 +- src/lerobot/motors/motors_bus.py | 101 ++- src/lerobot/scripts/lerobot_setup_can.py | 360 ++++++++++ src/lerobot/utils/import_utils.py | 1 + tests/motors/test_damiao.py | 66 ++ 14 files changed, 1740 insertions(+), 25 deletions(-) create mode 100644 docs/source/damiao.mdx create mode 100644 src/lerobot/motors/damiao/__init__.py create mode 100644 src/lerobot/motors/damiao/damiao.py create mode 100644 src/lerobot/motors/damiao/tables.py create mode 100644 src/lerobot/scripts/lerobot_setup_can.py create mode 100644 tests/motors/test_damiao.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 4298758b5..f86dd11c7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -115,6 +115,8 @@ title: Notebooks - local: feetech title: Updating Feetech Firmware + - local: 
damiao + title: Damiao Motors and CAN Bus title: "Resources" - sections: - local: contributing diff --git a/docs/source/damiao.mdx b/docs/source/damiao.mdx new file mode 100644 index 000000000..45388ab9b --- /dev/null +++ b/docs/source/damiao.mdx @@ -0,0 +1,165 @@ +# Damiao Motors and CAN Bus + +This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication. + +Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux. + +## Linux CAN Setup + +Before using Damiao motors, you need to set up the CAN interface on your Linux system. + +### Install CAN Utilities + +```bash +sudo apt-get install can-utils +``` + +### Configure CAN Interface (Manual) + +For standard CAN FD (recommended for OpenArms): + +```bash +sudo ip link set can0 down +sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on +sudo ip link set can0 up +``` + +For standard CAN (without FD): + +```bash +sudo ip link set can0 down +sudo ip link set can0 type can bitrate 1000000 +sudo ip link set can0 up +``` + +### Configure CAN Interface (Using LeRobot) + +LeRobot provides a utility script to setup and test CAN interfaces: + +```bash +# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses) +lerobot-setup-can --mode=setup --interfaces=can0,can1 +``` + +## Debugging CAN Communication + +Use the built-in debug tools to test motor communication: + +```bash +# Test motors on all interfaces +lerobot-setup-can --mode=test --interfaces=can0,can1 + +# Run speed/latency test +lerobot-setup-can --mode=speed --interfaces=can0 +``` + +The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output: + +``` +can0: UP (CAN FD) + Motor 0x01 (joint_1): ✓ FOUND + → Response 0x11 [FD]: 00112233... + Motor 0x02 (joint_2): ✓ FOUND + Motor 0x03 (joint_3): ✗ No response + ... 
+ Summary: 2/8 motors found +``` + +## Usage + +### Basic Setup + +```python +from lerobot.motors import Motor +from lerobot.motors.damiao import DamiaoMotorsBus + +# Define your motors with send/receive CAN IDs +motors = { + "joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11), + "joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12), + "joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13), +} + +# Create the bus +bus = DamiaoMotorsBus( + port="can0", # Linux socketcan interface + motors=motors, +) + +# Connect +bus.connect() +``` + +### Reading Motor States + +```python +# Read single motor position (degrees) +position = bus.read("Present_Position", "joint_1") + +# Read from multiple motors +positions = bus.sync_read("Present_Position") # All motors +positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"]) + +# Read all states at once (position, velocity, torque) +states = bus.sync_read_all_states() +# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...} +``` + +### Writing Motor Commands + +```python +# Enable torque +bus.enable_torque() + +# Set goal position (degrees) +bus.write("Goal_Position", "joint_1", 45.0) + +# Set positions for multiple motors +bus.sync_write("Goal_Position", { + "joint_1": 45.0, + "joint_2": -30.0, + "joint_3": 90.0, +}) + +# Disable torque +bus.disable_torque() +``` + +## Configuration Options + +| Parameter | Default | Description | +| -------------- | --------- | ----------------------------------------------------------- | +| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | + +## Motor Configuration + +Each motor requires: + +- `id`: CAN ID for sending commands +- `motor_type`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`) +- `recv_id`: CAN 
ID for receiving responses + +OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number. + +## Troubleshooting + +### No Response from Motors + +1. **Check power** +2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections +3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs +4. **Test CAN interface**: Run `candump can0` to see if messages are being received +5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0` + +### Motor Timeout Parameter + +If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value. + +### Verify CAN FD Status + +```bash +ip -d link show can0 | grep fd +``` diff --git a/pyproject.toml b/pyproject.toml index 75f617e75..27126f855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] +damiao = ["python-can>=4.2.0,<5.0.0"] # Robots gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] @@ -203,6 +204,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" +lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] @@ -278,6 +280,7 @@ default.extend-ignore-identifiers-re = [ "thw", "inpt", "ROBOTIS", + "OT_VALUE" ] # TODO: Uncomment when ready to use diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py index 850ef33d7..5df80d5ba 100644 --- a/src/lerobot/motors/__init__.py +++ b/src/lerobot/motors/__init__.py @@ -14,4 +14,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus +from .motors_bus import ( + Motor, + MotorCalibration, + MotorNormMode, +) diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 9832a1636..02bba454f 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -18,7 +18,7 @@ from dataclasses import dataclass os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" -from lerobot.motors import MotorCalibration, MotorsBus +from .motors_bus import MotorCalibration, MotorsBus BAR_LEN, BAR_THICKNESS = 450, 8 HANDLE_R = 10 diff --git a/src/lerobot/motors/damiao/__init__.py b/src/lerobot/motors/damiao/__init__.py new file mode 100644 index 000000000..8240138cf --- /dev/null +++ b/src/lerobot/motors/damiao/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .damiao import DamiaoMotorsBus +from .tables import * diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py new file mode 100644 index 000000000..dd0213fc3 --- /dev/null +++ b/src/lerobot/motors/damiao/damiao.py @@ -0,0 +1,808 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this file are derived from DM_Control_Python by cmjang. +# Licensed under the MIT License; see `LICENSE` for the full text: +# https://github.com/cmjang/DM_Control_Python + +import logging +import time +from contextlib import contextmanager +from copy import deepcopy +from functools import cached_property +from typing import TYPE_CHECKING, Any, TypedDict + +from lerobot.utils.import_utils import _can_available + +if TYPE_CHECKING or _can_available: + import can +else: + can.Message = object + can.interface = None + +import numpy as np + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value +from .tables import ( + AVAILABLE_BAUDRATES, + CAN_CMD_DISABLE, + CAN_CMD_ENABLE, + CAN_CMD_REFRESH, + CAN_CMD_SET_ZERO, + CAN_PARAM_ID, + DEFAULT_BAUDRATE, + DEFAULT_TIMEOUT_MS, + MIT_KD_RANGE, + MIT_KP_RANGE, + MOTOR_LIMIT_PARAMS, + MotorType, +) + +logger = logging.getLogger(__name__) + + +LONG_TIMEOUT_SEC = 0.1 +MEDIUM_TIMEOUT_SEC = 0.01 +SHORT_TIMEOUT_SEC = 0.001 +PRECISE_TIMEOUT_SEC = 0.0001 + + +class MotorState(TypedDict): + position: float + velocity: float + torque: float + temp_mos: float + temp_rotor: float + + +class DamiaoMotorsBus(MotorsBusBase): + """ + The Damiao 
implementation for a MotorsBus using CAN bus communication. + + This class uses python-can for CAN bus communication with Damiao motors. + For more info, see: + - python-can documentation: https://python-can.readthedocs.io/en/stable/ + - Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/ + - DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python + """ + + # CAN-specific settings + available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + can_interface: str = "auto", + use_can_fd: bool = True, + bitrate: int = 1000000, + data_bitrate: int | None = 5000000, + ): + """ + Initialize the Damiao motors bus. + + Args: + port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS) + motors: Dictionary mapping motor names to Motor objects + calibration: Optional calibration data + can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial) + use_can_fd: Whether to use CAN FD mode (default: True for OpenArms) + bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps) + data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False + """ + super().__init__(port, motors, calibration) + self.port = port + self.can_interface = can_interface + self.use_can_fd = use_can_fd + self.bitrate = bitrate + self.data_bitrate = data_bitrate + self.canbus: can.interface.Bus | None = None + self._is_connected = False + + # Map motor names to CAN IDs + self._motor_can_ids: dict[str, int] = {} + self._recv_id_to_motor: dict[int, str] = {} + self._motor_types: dict[str, MotorType] = {} + + for name, motor in self.motors.items(): + if motor.motor_type_str is None: + raise ValueError(f"Motor '{name}' is missing required 'motor_type'") + self._motor_types[name] = 
getattr(MotorType, motor.motor_type_str.upper().replace("-", "_")) + + # Map recv_id to motor name for filtering responses + if motor.recv_id is not None: + self._recv_id_to_motor[motor.recv_id] = name + + # State cache for handling packet drops safely + self._last_known_states: dict[str, MotorState] = { + name: { + "position": 0.0, + "velocity": 0.0, + "torque": 0.0, + "temp_mos": 0.0, + "temp_rotor": 0.0, + } + for name in self.motors + } + + # Dynamic gains storage + # Defaults: Kp=10.0 (Stiffness), Kd=0.5 (Damping) + self._gains: dict[str, dict[str, float]] = {name: {"kp": 10.0, "kd": 0.5} for name in self.motors} + + @property + def is_connected(self) -> bool: + """Check if the CAN bus is connected.""" + return self._is_connected and self.canbus is not None + + def connect(self, handshake: bool = True) -> None: + """ + Open the CAN bus and initialize communication. + + Args: + handshake: If True, ping all motors to verify they're present + """ + if self.is_connected: + raise DeviceAlreadyConnectedError( + f"{self.__class__.__name__}('{self.port}') is already connected." 
+ ) + + try: + # Auto-detect interface type based on port name + if self.can_interface == "auto": + if self.port.startswith("/dev/"): + self.can_interface = "slcan" + logger.info(f"Auto-detected slcan interface for port {self.port}") + else: + self.can_interface = "socketcan" + logger.info(f"Auto-detected socketcan interface for port {self.port}") + + # Connect to CAN bus + kwargs = { + "channel": self.port, + "bitrate": self.bitrate, + "interface": self.can_interface, + } + + if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None: + kwargs.update({"data_bitrate": self.data_bitrate, "fd": True}) + logger.info( + f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})" + ) + else: + logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})") + + self.canbus = can.interface.Bus(**kwargs) + self._is_connected = True + + if handshake: + self._handshake() + + logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.") + except Exception as e: + self._is_connected = False + raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e + + def _handshake(self) -> None: + """ + Verify all motors are present and populate initial state cache. + Raises ConnectionError if any motor fails to respond. + """ + logger.info("Starting handshake with motors...") + missing_motors = [] + + for motor_name in self.motors: + msg = self._refresh_motor(motor_name) + if msg is None: + missing_motors.append(motor_name) + else: + self._process_response(motor_name, msg) + time.sleep(MEDIUM_TIMEOUT_SEC) + + if missing_motors: + raise ConnectionError( + f"Handshake failed. The following motors did not respond: {missing_motors}. " + "Check power (24V) and CAN wiring." + ) + logger.info("Handshake successful. All motors ready.") + + def disconnect(self, disable_torque: bool = True) -> None: + """ + Close the CAN bus connection. 
+ + Args: + disable_torque: If True, disable torque on all motors before disconnecting + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") + + if disable_torque: + try: + self.disable_torque() + except Exception as e: + logger.warning(f"Failed to disable torque during disconnect: {e}") + + if self.canbus: + self.canbus.shutdown() + self.canbus = None + self._is_connected = False + logger.debug(f"{self.__class__.__name__} disconnected.") + + def configure_motors(self) -> None: + """Configure all motors with default settings.""" + # Damiao motors don't require much configuration in MIT mode + # Just ensure they're enabled + for motor in self.motors: + self._send_simple_command(motor, CAN_CMD_ENABLE) + time.sleep(MEDIUM_TIMEOUT_SEC) + + def _send_simple_command(self, motor: NameOrID, command_byte: int) -> None: + """Helper to send simple 8-byte commands (Enable, Disable, Zero).""" + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [command_byte] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + if msg := self._recv_motor_response(expected_recv_id=recv_id): + self._process_response(motor_name, msg) + else: + logger.debug(f"No response from {motor_name} after command 0x{command_byte:02X}") + + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + target_motors = self._get_motors_list(motors) + for motor in target_motors: + for _ in range(num_retry + 1): + try: + self._send_simple_command(motor, CAN_CMD_ENABLE) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(MEDIUM_TIMEOUT_SEC) + + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + target_motors = 
self._get_motors_list(motors) + for motor in target_motors: + for _ in range(num_retry + 1): + try: + self._send_simple_command(motor, CAN_CMD_DISABLE) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(MEDIUM_TIMEOUT_SEC) + + @contextmanager + def torque_disabled(self, motors: str | list[str] | None = None): + """ + Context manager that guarantees torque is re-enabled. + + This helper is useful to temporarily disable torque when configuring motors. + """ + self.disable_torque(motors) + try: + yield + finally: + self.enable_torque(motors) + + def set_zero_position(self, motors: str | list[str] | None = None) -> None: + """Set current position as zero for selected motors.""" + target_motors = self._get_motors_list(motors) + for motor in target_motors: + self._send_simple_command(motor, CAN_CMD_SET_ZERO) + time.sleep(MEDIUM_TIMEOUT_SEC) + + def _refresh_motor(self, motor: NameOrID) -> can.Message | None: + """Refresh motor status and return the response.""" + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + self.canbus.send(msg) + return self._recv_motor_response(expected_recv_id=recv_id) + + def _recv_motor_response( + self, expected_recv_id: int | None = None, timeout: float = 0.001 + ) -> can.Message | None: + """ + Receive a response from a motor. 
+ + Args: + expected_recv_id: If provided, only return messages from this CAN ID + timeout: Timeout in seconds (default: 1ms for high-speed operation) + Returns: + CAN message if received, None otherwise + """ + try: + start_time = time.time() + messages_seen = [] + while time.time() - start_time < timeout: + msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC) + if msg: + messages_seen.append(f"0x{msg.arbitration_id:02X}") + if expected_recv_id is None or msg.arbitration_id == expected_recv_id: + return msg + logger.debug( + f"Ignoring message from 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}" + ) + + if logger.isEnabledFor(logging.DEBUG): + if messages_seen: + logger.debug( + f"Received {len(messages_seen)} msgs from {set(messages_seen)}, expected 0x{expected_recv_id:02X}" + ) + else: + logger.debug(f"No CAN messages received (expected 0x{expected_recv_id:02X})") + except Exception as e: + logger.debug(f"Failed to receive CAN message: {e}") + return None + + def _recv_all_responses( + self, expected_recv_ids: list[int], timeout: float = 0.002 + ) -> dict[int, can.Message]: + """ + Efficiently receive responses from multiple motors at once. + Uses the OpenArms pattern: collect all available messages within timeout. 
+ + Args: + expected_recv_ids: List of CAN IDs we expect responses from + timeout: Total timeout in seconds (default: 2ms) + + Returns: + Dictionary mapping recv_id to CAN message + """ + responses = {} + expected_set = set(expected_recv_ids) + start_time = time.time() + + try: + while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: + # 100us poll timeout + msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC) + if msg and msg.arbitration_id in expected_set: + responses[msg.arbitration_id] = msg + if len(responses) == len(expected_recv_ids): + break + except Exception as e: + logger.debug(f"Error receiving responses: {e}") + + return responses + + def _encode_mit_packet( + self, + motor_type: MotorType, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + ) -> list[int]: + """Helper to encode control parameters into 8 bytes for MIT mode.""" + # Convert degrees to radians + position_rad = np.radians(position_degrees) + velocity_rad_per_sec = np.radians(velocity_deg_per_sec) + + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Encode parameters + kp_uint = self._float_to_uint(kp, *MIT_KP_RANGE, 12) + kd_uint = self._float_to_uint(kd, *MIT_KD_RANGE, 12) + q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16) + dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12) + tau_uint = self._float_to_uint(torque, -tmax, tmax, 12) + + # Pack data + data = [0] * 8 + data[0] = (q_uint >> 8) & 0xFF + data[1] = q_uint & 0xFF + data[2] = dq_uint >> 4 + data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF) + data[4] = kp_uint & 0xFF + data[5] = kd_uint >> 4 + data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF) + data[7] = tau_uint & 0xFF + return data + + def _mit_control( + self, + motor: NameOrID, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + ) -> None: + """Send MIT control command to a 
motor.""" + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + + recv_id = self._get_motor_recv_id(motor) + if msg := self._recv_motor_response(expected_recv_id=recv_id): + self._process_response(motor_name, msg) + else: + logger.debug(f"No response from {motor_name} after MIT control command") + + def _mit_control_batch( + self, + commands: dict[NameOrID, tuple[float, float, float, float, float]], + ) -> None: + """ + Send MIT control commands to multiple motors in batch. + Sends all commands first, then collects responses. + + Args: + commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque) + Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...} + """ + if not commands: + return + + recv_id_to_motor: dict[int, str] = {} + + # Step 1: Send all MIT control commands + for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + # Step 2: Collect responses and update state cache + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + + def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: + """Convert float to unsigned integer for CAN 
transmission.""" + x = max(x_min, min(x_max, x)) # Clamp to range + span = x_max - x_min + data_norm = (x - x_min) / span + return int(data_norm * ((1 << bits) - 1)) + + def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float: + """Convert unsigned integer from CAN to float.""" + span = x_max - x_min + data_norm = float(x) / ((1 << bits) - 1) + return data_norm * span + x_min + + def _decode_motor_state( + self, data: bytearray | bytes, motor_type: MotorType + ) -> tuple[float, float, float, int, int]: + """ + Decode motor state from CAN data. + Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor) + """ + if len(data) < 8: + raise ValueError("Invalid motor state data") + + # Extract encoded values + q_uint = (data[1] << 8) | data[2] + dq_uint = (data[3] << 4) | (data[4] >> 4) + tau_uint = ((data[4] & 0x0F) << 8) | data[5] + t_mos = data[6] + t_rotor = data[7] + + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Decode to physical values + position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16) + velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12) + torque = self._uint_to_float(tau_uint, -tmax, tmax, 12) + + return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor + + def _process_response(self, motor: str, msg: can.Message) -> None: + """Decode a message and update the motor state cache.""" + try: + motor_type = self._motor_types[motor] + pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type) + + self._last_known_states[motor] = { + "position": pos, + "velocity": vel, + "torque": torque, + "temp_mos": float(t_mos), + "temp_rotor": float(t_rotor), + } + except Exception as e: + logger.warning(f"Failed to decode response from {motor}: {e}") + + def read(self, data_name: str, motor: str) -> Value: + """Read a value from a single motor. 
Positions are always in degrees.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Refresh motor to get latest state + msg = self._refresh_motor(motor) + if msg is None: + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + raise ConnectionError( + f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). " + f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, " + f"3) Motor IDs are configured correctly using Damiao Debugging Tools" + ) + + self._process_response(motor, msg) + return self._get_cached_value(motor, data_name) + + def _get_cached_value(self, motor: str, data_name: str) -> Value: + """Retrieve a specific value from the cache.""" + state = self._last_known_states[motor] + mapping: dict[str, Any] = { + "Present_Position": state["position"], + "Present_Velocity": state["velocity"], + "Present_Torque": state["torque"], + "Temperature_MOS": state["temp_mos"], + "Temperature_Rotor": state["temp_rotor"], + } + if data_name not in mapping: + raise ValueError(f"Unknown data_name: {data_name}") + return mapping[data_name] + + def write( + self, + data_name: str, + motor: str, + value: Value, + ) -> None: + """ + Write a value to a single motor. Positions are always in degrees. + Can write 'Goal_Position', 'Kp', or 'Kd'. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if data_name in ("Kp", "Kd"): + self._gains[motor][data_name.lower()] = float(value) + elif data_name == "Goal_Position": + kp = self._gains[motor]["kp"] + kd = self._gains[motor]["kd"] + self._mit_control(motor, kp, kd, float(value), 0.0, 0.0) + else: + raise ValueError(f"Writing {data_name} not supported in MIT mode") + + def sync_read( + self, + data_name: str, + motors: str | list[str] | None = None, + ) -> dict[str, Value]: + """ + Read the same value from multiple motors simultaneously. 
+ """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + + result = {} + for motor in target_motors: + result[motor] = self._get_cached_value(motor, data_name) + return result + + def sync_read_all_states( + self, + motors: str | list[str] | None = None, + *, + num_retry: int = 0, + ) -> dict[str, MotorState]: + """ + Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle. + + Returns: + Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque' + Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...} + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + + result = {} + for motor in target_motors: + result[motor] = self._last_known_states[motor].copy() + return result + + def _batch_refresh(self, motors: list[str]) -> None: + """Internal helper to refresh a list of motors and update cache.""" + # Send refresh commands + for motor in motors: + motor_id = self._get_motor_id(motor) + data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + self.canbus.send(msg) + # Small delay to reduce bus congestion if necessary, though removed in sync_read previously + # precise_sleep(PRECISE_SLEEP_SEC) + + # Collect responses + expected_recv_ids = [self._get_motor_recv_id(m) for m in motors] + responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC) + + # Update cache + for motor in motors: + recv_id = self._get_motor_recv_id(motor) + msg = responses.get(recv_id) + if msg: + self._process_response(motor, msg) + else: + logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + + def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + """ + Write values to multiple motors simultaneously. Positions are always in degrees. 
+ """ + if data_name in ("Kp", "Kd"): + key = data_name.lower() + for motor, val in values.items(): + self._gains[motor][key] = float(val) + + elif data_name == "Goal_Position": + # Step 1: Send all MIT control commands + recv_id_to_motor: dict[int, str] = {} + for motor, value_degrees in values.items(): + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + kp = self._gains[motor]["kp"] + kd = self._gains[motor]["kd"] + + data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + precise_sleep(PRECISE_TIMEOUT_SEC) + + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + # Step 2: Collect responses and update state cache + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + else: + # Fall back to individual writes + for motor, value in values.items(): + self.write(data_name, motor, value) + + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration data from motors.""" + # Damiao motors don't store calibration internally + # Return existing calibration or empty dict + return self.calibration if self.calibration else {} + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration data to motors.""" + # Damiao motors don't store calibration internally + # Just cache it in memory + if cache: + self.calibration = calibration_dict + + def record_ranges_of_motion( + self, + motors: NameOrID | list[NameOrID] | None = None, + display_values: bool = True, + ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + """ + Interactively record the min/max values of each motor in degrees. 
+ + Move the joints by hand (with torque disabled) while the method streams live positions. + Press Enter to finish. + """ + target_motors = self._get_motors_list(motors) + + self.disable_torque(target_motors) + time.sleep(LONG_TIMEOUT_SEC) + + start_positions = self.sync_read("Present_Position", target_motors) + mins = start_positions.copy() + maxes = start_positions.copy() + + print("\nMove joints through their full range of motion. Press ENTER when done.") + user_pressed_enter = False + + while not user_pressed_enter: + positions = self.sync_read("Present_Position", target_motors) + + for motor in target_motors: + if motor in positions: + mins[motor] = min(positions[motor], mins.get(motor, positions[motor])) + maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor])) + + if display_values: + print("\n" + "=" * 50) + print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}") + print("-" * 50) + for motor in target_motors: + if motor in positions: + print( + f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + move_cursor_up(len(target_motors) + 4) + + time.sleep(LONG_TIMEOUT_SEC) + + self.enable_torque(target_motors) + + for motor in target_motors: + if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5): + raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)") + + return mins, maxes + + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + """Convert motor specification to list of motor names.""" + if motors is None: + return list(self.motors.keys()) + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors + else: + raise TypeError(f"Invalid motors type: {type(motors)}") + + def _get_motor_id(self, motor: NameOrID) -> int: + """Get CAN ID for a motor.""" + if 
isinstance(motor, str): + if motor in self.motors: + return self.motors[motor].id + else: + raise ValueError(f"Unknown motor: {motor}") + else: + return motor + + def _get_motor_name(self, motor: NameOrID) -> str: + """Get motor name from name or ID.""" + if isinstance(motor, str): + return motor + else: + for name, m in self.motors.items(): + if m.id == motor: + return name + raise ValueError(f"Unknown motor ID: {motor}") + + def _get_motor_recv_id(self, motor: NameOrID) -> int: + """Get motor recv_id from name or ID.""" + motor_name = self._get_motor_name(motor) + motor_obj = self.motors.get(motor_name) + if motor_obj and motor_obj.recv_id is not None: + return motor_obj.recv_id + else: + raise ValueError(f"Motor {motor_obj} doesn't have a valid recv_id (None).") + + @cached_property + def is_calibrated(self) -> bool: + """Check if motors are calibrated.""" + return bool(self.calibration) diff --git a/src/lerobot/motors/damiao/tables.py b/src/lerobot/motors/damiao/tables.py new file mode 100644 index 000000000..22d1624fa --- /dev/null +++ b/src/lerobot/motors/damiao/tables.py @@ -0,0 +1,209 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration tables for Damiao motors.""" + +from enum import IntEnum + + +# Motor type definitions +class MotorType(IntEnum): + DM3507 = 0 + DM4310 = 1 + DM4310_48V = 2 + DM4340 = 3 + DM4340_48V = 4 + DM6006 = 5 + DM8006 = 6 + DM8009 = 7 + DM10010L = 8 + DM10010 = 9 + DMH3510 = 10 + DMH6215 = 11 + DMG6220 = 12 + + +# Control modes +class ControlMode(IntEnum): + MIT = 1 + POS_VEL = 2 + VEL = 3 + TORQUE_POS = 4 + + +# Motor variable IDs (RID) +class MotorVariable(IntEnum): + UV_VALUE = 0 + KT_VALUE = 1 + OT_VALUE = 2 + OC_VALUE = 3 + ACC = 4 + DEC = 5 + MAX_SPD = 6 + MST_ID = 7 + ESC_ID = 8 + TIMEOUT = 9 + CTRL_MODE = 10 + DAMP = 11 + INERTIA = 12 + HW_VER = 13 + SW_VER = 14 + SN = 15 + NPP = 16 + RS = 17 + LS = 18 + FLUX = 19 + GR = 20 + PMAX = 21 + VMAX = 22 + TMAX = 23 + I_BW = 24 + KP_ASR = 25 + KI_ASR = 26 + KP_APR = 27 + KI_APR = 28 + OV_VALUE = 29 + GREF = 30 + DETA = 31 + V_BW = 32 + IQ_C1 = 33 + VL_C1 = 34 + CAN_BR = 35 + SUB_VER = 36 + U_OFF = 50 + V_OFF = 51 + K1 = 52 + K2 = 53 + M_OFF = 54 + DIR = 55 + P_M = 80 + XOUT = 81 + + +# Motor limit parameters [PMAX, VMAX, TMAX] +# PMAX: Maximum position (rad) +# VMAX: Maximum velocity (rad/s) +# TMAX: Maximum torque (N·m) +MOTOR_LIMIT_PARAMS = { + MotorType.DM3507: (12.5, 30, 10), + MotorType.DM4310: (12.5, 30, 10), + MotorType.DM4310_48V: (12.5, 50, 10), + MotorType.DM4340: (12.5, 8, 28), + MotorType.DM4340_48V: (12.5, 10, 28), + MotorType.DM6006: (12.5, 45, 20), + MotorType.DM8006: (12.5, 45, 40), + MotorType.DM8009: (12.5, 45, 54), + MotorType.DM10010L: (12.5, 25, 200), + MotorType.DM10010: (12.5, 20, 200), + MotorType.DMH3510: (12.5, 280, 1), + MotorType.DMH6215: (12.5, 45, 10), + MotorType.DMG6220: (12.5, 45, 10), +} + +# Motor model names +MODEL_NAMES = { + MotorType.DM3507: "dm3507", + MotorType.DM4310: "dm4310", + MotorType.DM4310_48V: "dm4310_48v", + MotorType.DM4340: "dm4340", + MotorType.DM4340_48V: "dm4340_48v", + MotorType.DM6006: "dm6006", + MotorType.DM8006: "dm8006", + MotorType.DM8009: 
"dm8009", + MotorType.DM10010L: "dm10010l", + MotorType.DM10010: "dm10010", + MotorType.DMH3510: "dmh3510", + MotorType.DMH6215: "dmh6215", + MotorType.DMG6220: "dmg6220", +} + +# Motor resolution table (encoder counts per revolution) +MODEL_RESOLUTION = { + "dm3507": 65536, + "dm4310": 65536, + "dm4310_48v": 65536, + "dm4340": 65536, + "dm4340_48v": 65536, + "dm6006": 65536, + "dm8006": 65536, + "dm8009": 65536, + "dm10010l": 65536, + "dm10010": 65536, + "dmh3510": 65536, + "dmh6215": 65536, + "dmg6220": 65536, +} + +# CAN baudrates supported by Damiao motors +AVAILABLE_BAUDRATES = [ + 125000, # 0: 125 kbps + 200000, # 1: 200 kbps + 250000, # 2: 250 kbps + 500000, # 3: 500 kbps + 1000000, # 4: 1 mbps (default for OpenArms) + 2000000, # 5: 2 mbps + 2500000, # 6: 2.5 mbps + 3200000, # 7: 3.2 mbps + 4000000, # 8: 4 mbps + 5000000, # 9: 5 mbps +] +DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms + +# Default timeout in milliseconds +DEFAULT_TIMEOUT_MS = 1000 + +# OpenArms specific configurations +# Based on: https://docs.openarm.dev/software/setup/configure-test +# OpenArms has 7 DOF per arm (14 total for dual arm) +OPENARMS_ARM_MOTOR_IDS = { + "joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan + "joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift + "joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex + "joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex + "joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll + "joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch + "joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation +} + +OPENARMS_GRIPPER_MOTOR_IDS = { + "gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper +} + +# Default motor types for OpenArms +OPENARMS_DEFAULT_MOTOR_TYPES = { + "joint_1": MotorType.DM8009, # Shoulder pan - high torque + "joint_2": MotorType.DM8009, # Shoulder lift - high torque + "joint_3": MotorType.DM4340, # Shoulder rotation + "joint_4": MotorType.DM4340, # Elbow flex + 
"joint_5": MotorType.DM4310, # Wrist roll + "joint_6": MotorType.DM4310, # Wrist pitch + "joint_7": MotorType.DM4310, # Wrist rotation + "gripper": MotorType.DM4310, # Gripper +} + +# MIT control parameter ranges +MIT_KP_RANGE = (0.0, 500.0) +MIT_KD_RANGE = (0.0, 5.0) + +# CAN frame command IDs +CAN_CMD_ENABLE = 0xFC +CAN_CMD_DISABLE = 0xFD +CAN_CMD_SET_ZERO = 0xFE +CAN_CMD_REFRESH = 0xCC +CAN_CMD_QUERY_PARAM = 0x33 +CAN_CMD_WRITE_PARAM = 0x55 +CAN_CMD_SAVE_PARAM = 0xAA + +# CAN ID for parameter operations +CAN_PARAM_ID = 0x7FF diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index 01bfcf544..c6752ee96 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -22,9 +22,8 @@ import logging from copy import deepcopy from enum import Enum -from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement - -from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from ..encoding_utils import decode_twos_complement, encode_twos_complement +from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from .tables import ( AVAILABLE_BAUDRATES, MODEL_BAUDRATE_TABLE, @@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]: return data -class DynamixelMotorsBus(MotorsBus): +class DynamixelMotorsBus(SerialMotorsBus): """ The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with the motors. 
For more info, see the Dynamixel SDK Documentation: @@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus): for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") - self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 2ea57af12..7ce3388b6 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -17,9 +17,8 @@ from copy import deepcopy from enum import Enum from pprint import pformat -from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude - -from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude +from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from .tables import ( FIRMWARE_MAJOR_VERSION, FIRMWARE_MINOR_VERSION, @@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802 self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50 -class FeetechMotorsBus(MotorsBus): +class FeetechMotorsBus(SerialMotorsBus): """ The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk. 
@@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) - def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") - self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) addr, length = get_address(self.model_ctrl_table, model, "Lock") - self._write(addr, length, motor_id, 0, num_retry=num_retry) + self._write(addr, length, motor, 0, num_retry=num_retry) def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 91bee994a..c04f718b6 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -19,6 +19,8 @@ # TODO(aliberts): Add block noqa when feature below is available # https://github.com/astral-sh/ruff/issues/3711 +from __future__ import annotations + import abc import logging from contextlib import contextmanager @@ -41,6 +43,81 @@ Value: TypeAlias = int | float logger = logging.getLogger(__name__) +class MotorsBusBase(abc.ABC): + """ + Base class for all motor bus implementations. + + This is a minimal interface that all motor buses must implement, regardless of their + communication protocol (serial, CAN, etc.). 
+ """ + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + ): + self.port = port + self.motors = motors + self.calibration = calibration if calibration else {} + + @abc.abstractmethod + def connect(self, handshake: bool = True) -> None: + """Establish connection to the motors.""" + pass + + @abc.abstractmethod + def disconnect(self, disable_torque: bool = True) -> None: + """Disconnect from the motors.""" + pass + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """Check if connected to the motors.""" + pass + + @abc.abstractmethod + def read(self, data_name: str, motor: str) -> Value: + """Read a value from a single motor.""" + pass + + @abc.abstractmethod + def write(self, data_name: str, motor: str, value: Value) -> None: + """Write a value to a single motor.""" + pass + + @abc.abstractmethod + def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]: + """Read a value from multiple motors.""" + pass + + @abc.abstractmethod + def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + """Write values to multiple motors.""" + pass + + @abc.abstractmethod + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + pass + + @abc.abstractmethod + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + pass + + @abc.abstractmethod + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration parameters from the motors.""" + pass + + @abc.abstractmethod + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration parameters to the motors.""" + pass + + def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]: ctrl_table = 
model_ctrl_table.get(model) if ctrl_table is None: @@ -97,6 +174,8 @@ class Motor: id: int model: str norm_mode: MotorNormMode + motor_type_str: str | None = None + recv_id: int | None = None class PortHandler(Protocol): @@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol): def txPacket(self): ... -class MotorsBus(abc.ABC): +class SerialMotorsBus(MotorsBusBase): """ - A MotorsBus allows to efficiently read and write to the attached motors. + A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication. It represents several motors daisy-chained together and connected through a serial port. - There are currently two implementations of this abstract class: + There are currently two implementations of this class: - DynamixelMotorsBus - FeetechMotorsBus - Note: This class may evolve in the future should we add support for other types of bus. + This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.). A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). 
To find the port, you can run our utility script: @@ -260,9 +339,7 @@ class MotorsBus(abc.ABC): motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, ): - self.port = port - self.motors = motors - self.calibration = calibration if calibration else {} + super().__init__(port, motors, calibration) self.port_handler: PortHandler self.packet_handler: PacketHandler @@ -532,7 +609,7 @@ class MotorsBus(abc.ABC): self.set_baudrate(self.default_baudrate) @abc.abstractmethod - def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]: + def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: pass @abc.abstractmethod @@ -545,13 +622,13 @@ class MotorsBus(abc.ABC): pass @abc.abstractmethod - def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: """Disable torque on selected motors. Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM). Args: - motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a + motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a list of names or `None` to affect every registered motor. Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. @@ -1194,3 +1271,7 @@ class MotorsBus(abc.ABC): for id_, value in ids_values.items(): data = self._serialize_data(value, length) self.sync_writer.addParam(id_, data) + + +# Backward compatibility alias +MotorsBus: TypeAlias = SerialMotorsBus diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py new file mode 100644 index 000000000..55de74724 --- /dev/null +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -0,0 +1,360 @@ +# Copyright 2025 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Setup and debug CAN interfaces for Damiao motors (e.g., OpenArms). + +Examples: + +Setup CAN interfaces with CAN FD: +```shell +lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3 +``` + +Test motors on a single interface: +```shell +lerobot-setup-can --mode=test --interfaces=can0 +``` + +Test motors on all interfaces: +```shell +lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3 +``` + +Speed test: +```shell +lerobot-setup-can --mode=speed --interfaces=can0 +``` +""" + +import subprocess +import sys +import time +from dataclasses import dataclass, field + +import draccus + +from lerobot.utils.import_utils import is_package_available + +MOTOR_NAMES = { + 0x01: "joint_1", + 0x02: "joint_2", + 0x03: "joint_3", + 0x04: "joint_4", + 0x05: "joint_5", + 0x06: "joint_6", + 0x07: "joint_7", + 0x08: "gripper", +} + + +@dataclass +class CANSetupConfig: + mode: str = "test" + interfaces: str = "can0" # Comma-separated, e.g. 
"can0,can1,can2,can3" + bitrate: int = 1000000 + data_bitrate: int = 5000000 + use_fd: bool = True + motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09))) + timeout: float = 1.0 + speed_iterations: int = 100 + + def get_interfaces(self) -> list[str]: + return [i.strip() for i in self.interfaces.split(",") if i.strip()] + + +def check_interface_status(interface: str) -> tuple[bool, str, bool]: + """Check if CAN interface is UP and configured.""" + try: + result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607 + if result.returncode != 0: + return False, "Interface not found", False + + output = result.stdout + is_up = "UP" in output + is_fd = "fd on" in output.lower() or "canfd" in output.lower() + status = "UP" if is_up else "DOWN" + if is_fd: + status += " (CAN FD)" + + return is_up, status, is_fd + except FileNotFoundError: + return False, "ip command not found", False + + +def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool: + """Configure a CAN interface.""" + try: + subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607 + + cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)] + if use_fd: + cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"]) + + result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607 + if result.returncode != 0: + print(f" ✗ Failed to configure: {result.stderr}") + return False + + result = subprocess.run( # nosec B607 + ["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True + ) + if result.returncode != 0: + print(f" ✗ Failed to bring up: {result.stderr}") + return False + + return True + except Exception as e: + print(f" ✗ Error: {e}") + return False + + +def test_motor(bus, motor_id: int, timeout: float, use_fd: bool): + """Test a single motor and return responses.""" + import can + + enable_msg = 
can.Message( + arbitration_id=motor_id, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC], + is_extended_id=False, + is_fd=use_fd, + ) + + try: + bus.send(enable_msg) + except Exception as e: + return None, f"Send error: {e}" + + responses = [] + start_time = time.time() + + while time.time() - start_time < timeout: + msg = bus.recv(timeout=0.1) + if msg: + responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False))) + + disable_msg = can.Message( + arbitration_id=motor_id, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD], + is_extended_id=False, + is_fd=use_fd, + ) + try: + bus.send(disable_msg) + except Exception: + print(f"Error sending message to motor 0x{motor_id:02X}") + + return responses, None + + +def test_interface(cfg: CANSetupConfig, interface: str): + """Test all motors on a CAN interface.""" + import can + + is_up, status, _ = check_interface_status(interface) + print(f"\n{interface}: {status}") + + if not is_up: + print(f" ⚠ Interface is not UP. 
Run: lerobot-setup-can --mode=setup --interfaces {interface}") + return {} + + try: + kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate} + if cfg.use_fd: + kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True}) + bus = can.interface.Bus(**kwargs) + except Exception as e: + print(f" ✗ Connection failed: {e}") + return {} + + results = {} + try: + while bus.recv(timeout=0.01): + pass + + for motor_id in cfg.motor_ids: + motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}") + responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd) + + if error: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}") + results[motor_id] = {"found": False, "error": error} + elif responses: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND") + for resp_id, data, is_fd in responses: + fd_flag = " [FD]" if is_fd else "" + print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}") + results[motor_id] = {"found": True, "responses": responses} + else: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response") + results[motor_id] = {"found": False} + + time.sleep(0.05) + finally: + bus.shutdown() + + found = sum(1 for r in results.values() if r.get("found")) + print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found") + return results + + +def speed_test(cfg: CANSetupConfig, interface: str): + """Test communication speed with motors.""" + import can + + is_up, status, _ = check_interface_status(interface) + if not is_up: + print(f"{interface}: {status} - skipping") + return + + print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...") + + try: + kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate} + if cfg.use_fd: + kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True}) + bus = can.interface.Bus(**kwargs) + except Exception as e: + print(f" ✗ Connection failed: {e}") + return + + responding_motor = None + for motor_id in cfg.motor_ids: + responses, 
_ = test_motor(bus, motor_id, 0.5, cfg.use_fd) + if responses: + responding_motor = motor_id + break + + if not responding_motor: + print(" ✗ No responding motors found") + bus.shutdown() + return + + print(f" Testing with motor 0x{responding_motor:02X}...") + latencies = [] + + for _ in range(cfg.speed_iterations): + start = time.perf_counter() + msg = can.Message( + arbitration_id=responding_motor, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC], + is_extended_id=False, + is_fd=cfg.use_fd, + ) + bus.send(msg) + resp = bus.recv(timeout=0.1) + if resp: + latencies.append((time.perf_counter() - start) * 1000) + + bus.shutdown() + + if latencies: + avg_latency = sum(latencies) / len(latencies) + hz = 1000.0 / avg_latency if avg_latency > 0 else 0 + print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}") + print(f" ✓ Avg latency: {avg_latency:.2f} ms") + print(f" ✓ Max frequency: {hz:.1f} Hz") + else: + print(" ✗ No successful responses") + + +def run_setup(cfg: CANSetupConfig): + """Setup CAN interfaces.""" + print("=" * 50) + print("CAN Interface Setup") + print("=" * 50) + print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}") + print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps") + if cfg.use_fd: + print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps") + print() + + interfaces = cfg.get_interfaces() + for interface in interfaces: + print(f"Configuring {interface}...") + if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd): + is_up, status, _ = check_interface_status(interface) + print(f" ✓ {interface}: {status}") + else: + print(f" ✗ {interface}: Failed") + + print("\nSetup complete!") + print("\nNext: Test motors with:") + print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}") + + +def run_test(cfg: CANSetupConfig): + """Test motors on CAN interfaces.""" + print("=" * 50) + print("CAN Motor Test") + print("=" * 50) + print(f"Testing motors 
0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}") + print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}") + print() + + interfaces = cfg.get_interfaces() + all_results = {} + for interface in interfaces: + all_results[interface] = test_interface(cfg, interface) + + total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values()) + + print("\n" + "=" * 50) + print("Summary") + print("=" * 50) + print(f"Total motors found: {total_found}") + + if total_found == 0: + print("\n⚠ No motors found! Check:") + print(" 1. Motors are powered (24V)") + print(" 2. CAN wiring (CANH, CANL, GND)") + print(" 3. Motor timeout parameter > 0 (use Damiao tools)") + print(" 4. 120Ω termination at both cable ends") + print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}") + + +def run_speed(cfg: CANSetupConfig): + """Run speed tests on CAN interfaces.""" + print("=" * 50) + print("CAN Speed Test") + print("=" * 50) + + for interface in cfg.get_interfaces(): + speed_test(cfg, interface) + + +@draccus.wrap() +def setup_can(cfg: CANSetupConfig): + if not is_package_available("can"): + print("Error: python-can not installed. 
Install with: pip install python-can") + sys.exit(1) + + if cfg.mode == "setup": + run_setup(cfg) + elif cfg.mode == "test": + run_test(cfg) + elif cfg.mode == "speed": + run_speed(cfg) + else: + print(f"Unknown mode: {cfg.mode}") + print("Available modes: setup, test, speed") + sys.exit(1) + + +def main(): + setup_can() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index a499b96c7..c33a73589 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -73,6 +73,7 @@ _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") _reachy2_sdk_available = is_package_available("reachy2_sdk") +_can_available = is_package_available("python-can", "can") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/motors/test_damiao.py b/tests/motors/test_damiao.py new file mode 100644 index 000000000..7ce1af34f --- /dev/null +++ b/tests/motors/test_damiao.py @@ -0,0 +1,66 @@ +"""Minimal test script for Damiao motor with ID 3.""" + +import pytest + +from lerobot.utils.import_utils import _can_available + +if not _can_available: + pytest.skip("python-can not available", allow_module_level=True) + +from lerobot.motors import Motor +from lerobot.motors.damiao import DamiaoMotorsBus + + +@pytest.mark.skip(reason="Requires physical Damiao motor and CAN interface") +def test_damiao_motor(): + motors = { + "joint_3": Motor( + id=0x03, + model="damiao", + norm_mode="degrees", + motor_type_str="dm4310", + recv_id=0x13, + ), + } + + bus = DamiaoMotorsBus(port="can0", motors=motors) + + try: + print("Connecting...") + bus.connect() + print("✓ Connected") + + print("Enabling torque...") + bus.enable_torque() + print("✓ Torque enabled") + + print("Reading all states...") + states = bus.sync_read_all_states() + print(f"✓ States: {states}") + + print("Reading 
position...") + positions = bus.sync_read("Present_Position") + print(f"✓ Position: {positions}") + + print("Testing MIT control batch...") + current_pos = states["joint_3"]["position"] + commands = {"joint_3": (10.0, 0.5, current_pos, 0.0, 0.0)} + bus._mit_control_batch(commands) + print("✓ MIT control batch sent") + + print("Disabling torque...") + bus.disable_torque() + print("✓ Torque disabled") + + print("Setting zero position...") + bus.set_zero_position() + print("✓ Zero position set") + + finally: + print("Disconnecting...") + bus.disconnect(disable_torque=True) + print("✓ Disconnected") + + +if __name__ == "__main__": + test_damiao_motor() From 0c0c171d3543fafe587e0ac1768dd033aaf12ebf Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Tue, 27 Jan 2026 13:33:45 +0100 Subject: [PATCH 03/43] Add robot images to docs (#2862) * Add robot images to docs * increase img size * remove img so100 --- docs/source/earthrover_mini_plus.mdx | 6 ++++++ docs/source/lekiwi.mdx | 6 ++++++ docs/source/so101.mdx | 13 +++++++++++++ 3 files changed, 25 insertions(+) diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index e3ffa6b32..d8083336a 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -1,5 +1,11 @@ # EarthRover Mini Plus +EarthRover Mini Plus + The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models. ## What You Need diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index 511521580..b339225d8 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -1,5 +1,11 @@ # LeKiwi +LeKiwi + In the steps below, we explain how to assemble the LeKiwi mobile robot. 
## Source the parts diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index cf882b373..7c9df588a 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -1,5 +1,18 @@ # SO-101 +
+ SO-101 + SO-101 +
+ In the steps below, we explain how to assemble our flagship robot, the SO-101. ## Source the parts From f6b1c39b785af0f2f78899f5de6e008f3295e594 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Tue, 27 Jan 2026 14:31:53 +0000 Subject: [PATCH 04/43] docs: update libero (#2857) * update libero docs * Update docs/source/libero.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jade Choghari --------- Signed-off-by: Jade Choghari Co-authored-by: Jade Choghari Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/libero.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 3617f3b25..def974531 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -42,6 +42,7 @@ lerobot-eval \ ``` - `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.). +- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite. - `--eval.batch_size` controls how many environments run in parallel. - `--eval.n_episodes` sets how many episodes to run in total. 
From 736b43f3cfb5db2450fa787a45f645e1309caa00 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 28 Jan 2026 13:31:27 +0100 Subject: [PATCH 05/43] Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861) * Fix aggeregation of datasets when subdatasets are already a result of a previous merge * docstring * respond to copilot review + add regression test * Remove unnecessary int conversion for indicies --- src/lerobot/datasets/aggregate.py | 100 ++++++++++++++++++++++++------ tests/datasets/test_aggregate.py | 89 ++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 18 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 94ffe602e..7020545d2 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -116,6 +116,9 @@ def update_meta_data( Adjusts all indices and timestamps to account for previously aggregated data and videos in the destination dataset. + For data file indices, uses the 'src_to_dst' mapping from aggregate_data() + to correctly map source file indices to their destination locations. + Args: df: DataFrame containing the metadata to be updated. dst_meta: Destination dataset metadata. 
@@ -129,8 +132,50 @@ def update_meta_data( df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"] df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"] - df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] - df["data/file_index"] = df["data/file_index"] + data_idx["file"] + + # Update data file indices using source-to-destination mapping + # This is critical for handling datasets that are already results of a merge + data_src_to_dst = data_idx.get("src_to_dst", {}) + if data_src_to_dst: + # Store original indices for lookup + df["_orig_data_chunk"] = df["data/chunk_index"].copy() + df["_orig_data_file"] = df["data/file_index"].copy() + + # Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file) + # This is much faster than per-row iteration for large metadata tables + mapping_index = pd.MultiIndex.from_tuples( + list(data_src_to_dst.keys()), + names=["chunk_index", "file_index"], + ) + mapping_values = list(data_src_to_dst.values()) + mapping_df = pd.DataFrame( + mapping_values, + index=mapping_index, + columns=["dst_chunk", "dst_file"], + ) + + # Construct a MultiIndex for each row based on original data indices + row_index = pd.MultiIndex.from_arrays( + [df["_orig_data_chunk"], df["_orig_data_file"]], + names=["chunk_index", "file_index"], + ) + + # Align mapping to rows; missing keys fall back to the default destination + reindexed = mapping_df.reindex(row_index) + reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna( + {"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]} + ) + + # Assign mapped destination indices back to the DataFrame + df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy() + df["data/file_index"] = reindexed["dst_file"].to_numpy() + + # Clean up temporary columns + df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"]) + else: + # Fallback to simple offset (backward compatibility for single-file 
sources) + df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] + df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): # Store original video file indices before updating orig_chunk_col = f"videos/{key}/chunk_index" @@ -146,8 +191,7 @@ def update_meta_data( if src_to_dst: # Map each episode to its correct destination file and apply offset for idx in df.index: - # Convert to Python int to avoid numpy type mismatch in dict lookup - src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"])) + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) # Get destination chunk/file for this source file dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"])) @@ -163,8 +207,7 @@ def update_meta_data( df[orig_chunk_col] = video_idx["chunk"] df[orig_file_col] = video_idx["file"] for idx in df.index: - # Convert to Python int to avoid numpy type mismatch in dict lookup - src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"])) + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) offset = src_to_offset.get(src_key, 0) df.at[idx, f"videos/{key}/from_timestamp"] += offset df.at[idx, f"videos/{key}/to_timestamp"] += offset @@ -262,6 +305,10 @@ def aggregate_datasets( meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) + # Clear the src_to_dst mapping after processing each source dataset + # to avoid interference between different source datasets + data_idx.pop("src_to_dst", None) + dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info["total_frames"] += src_meta.total_frames @@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu dst_file_durations = video_idx["dst_file_durations"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: - # Convert to Python int to ensure consistent dict keys - src_chunk_idx = int(src_chunk_idx) - 
src_file_idx = int(src_file_idx) - src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, chunk_index=src_chunk_idx, @@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si Reads source data files, updates indices to match the aggregated dataset, and writes them to the destination with proper file rotation. + Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file) + which is critical for correctly updating episode metadata when source datasets + have multiple data files (e.g., from a previous merge operation). + Args: src_meta: Source dataset metadata. dst_meta: Destination dataset metadata. data_idx: Dictionary tracking data chunk and file indices. + data_files_size_in_mb: Maximum size for data files in MB. + chunk_size: Maximum number of files per chunk. Returns: dict: Updated data_idx with current chunk and file indices. @@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si # retrieve features schema for proper image typing in parquet hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None + # Track source to destination file mapping for metadata update + # This is critical for handling datasets that are already results of a merge + src_to_dst: dict[tuple[int, int], tuple[int, int]] = {} + for src_chunk_idx, src_file_idx in unique_chunk_file_ids: src_path = src_meta.root / DEFAULT_DATA_PATH.format( chunk_index=src_chunk_idx, file_index=src_file_idx @@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) - data_idx = append_or_create_parquet_file( + # Write data and get the actual destination file it was written to + # This avoids duplicating the rotation logic here + data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file( df, src_path, data_idx, @@ -433,6 +488,12 
@@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si hf_features=hf_features, ) + # Record the mapping from source to actual destination + src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file) + + # Add the mapping to data_idx for use in metadata update + data_idx["src_to_dst"] = src_to_dst + return data_idx @@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - meta_idx = append_or_create_parquet_file( + meta_idx, _ = append_or_create_parquet_file( df, src_path, meta_idx, @@ -501,7 +562,7 @@ def append_or_create_parquet_file( contains_images: bool = False, aggr_root: Path = None, hf_features: datasets.Features | None = None, -): +) -> tuple[dict[str, int], tuple[int, int]]: """Appends data to an existing parquet file or creates a new one based on size constraints. Manages file rotation when size limits are exceeded to prevent individual files @@ -519,9 +580,11 @@ def append_or_create_parquet_file( hf_features: Optional HuggingFace Features schema for proper image typing. Returns: - dict: Updated index dictionary with current chunk and file indices. + tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict + and (dst_chunk, dst_file) is the actual destination file the data was written to. 
""" - dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) if not dst_path.exists(): dst_path.parent.mkdir(parents=True, exist_ok=True) @@ -529,14 +592,15 @@ def append_or_create_parquet_file( to_parquet_with_hf_images(df, dst_path, features=hf_features) else: df.to_parquet(dst_path) - return idx + return idx, (dst_chunk, dst_file) src_size = get_parquet_file_size_in_mb(src_path) dst_size = get_parquet_file_size_in_mb(dst_path) if dst_size + src_size >= max_mb: idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) - new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) new_path.parent.mkdir(parents=True, exist_ok=True) final_df = df target_path = new_path @@ -555,7 +619,7 @@ def append_or_create_parquet_file( else: final_df.to_parquet(target_path) - return idx + return idx, (dst_chunk, dst_file) def finalize_aggregation(aggr_meta, all_metadata): diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 031c29d60..3609bac24 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" assert_dataset_iteration_works(aggr_ds) + + +def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): + """Regression test for aggregating a dataset that is itself a result of a previous merge. 
+ + This test reproduces the bug where merging datasets with multiple parquet files + (e.g., from a previous merge with file rotation) would cause FileNotFoundError + because metadata file indices were incorrectly preserved instead of being mapped + to their actual destination files. + + The fix adds src_to_dst tracking in aggregate_data() to correctly map source + file indices to destination file indices. + """ + # Step 1: Create datasets A and B + ds_a = lerobot_dataset_factory( + root=tmp_path / "ds_a", + repo_id=f"{DUMMY_REPO_ID}_a", + total_episodes=4, + total_frames=200, + ) + ds_b = lerobot_dataset_factory( + root=tmp_path / "ds_b", + repo_id=f"{DUMMY_REPO_ID}_b", + total_episodes=4, + total_frames=200, + ) + + # Step 2: Merge A+B into AB with small file size to force multiple files + aggregate_datasets( + repo_ids=[ds_a.repo_id, ds_b.repo_id], + roots=[ds_a.root, ds_b.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_ab", + aggr_root=tmp_path / "ds_ab", + data_files_size_in_mb=0.01, # Force file rotation + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_ab") + ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab") + + # Verify AB has multiple data files (file rotation occurred) + ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet")) + assert len(ab_data_files) > 1, "First merge should create multiple parquet files" + + # Step 3: Create dataset C + ds_c = lerobot_dataset_factory( + root=tmp_path / "ds_c", + repo_id=f"{DUMMY_REPO_ID}_c", + total_episodes=2, + total_frames=100, + ) + + # Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED + aggregate_datasets( + repo_ids=[ds_ab.repo_id, ds_c.repo_id], + roots=[ds_ab.root, ds_c.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_abc", + 
aggr_root=tmp_path / "ds_abc", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_abc") + ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc") + + # Step 5: Verify all data files referenced in metadata actually exist + for ep_idx in range(ds_abc.num_episodes): + data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx) + assert data_file_path.exists(), ( + f"Episode {ep_idx} references non-existent file: {data_file_path}\n" + "This indicates the src_to_dst mapping fix is not working correctly." + ) + + # Step 6: Verify we can iterate through the entire dataset without FileNotFoundError + expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes + expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames + + assert ds_abc.num_episodes == expected_episodes + assert ds_abc.num_frames == expected_frames + + # This would raise FileNotFoundError before the fix + assert_dataset_iteration_works(ds_abc) From bf337e716da18054e463003fa37f47df2aa9bfe3 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 28 Jan 2026 14:28:51 +0100 Subject: [PATCH 06/43] feat(robots): add OpenArm robot & teleoperator (#2795) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * feat(robot): add openarm leader Co-authored-by: Pepijn * feat(robot): add openarm follower Co-authored-by: Pepijn * refactor(robot): remove mechanical compensations and double arm assumption + rename * chore(robots): 
remove left arm references * refactor(teleop): multiple improvements to leader * refactor(teleop): multiple improvements to leader * feat(robots): add open arm to util CLI * chore(robot): add alias openarm * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * fix(robots): open arm mirrored config for joint limits * chore(motors): update position_kd gain values * chore(robots): set to 0 if openarm is calibrated at connect time * chore(robots): remove macos in open arm as can doesn't support it * chore(robots): update for motor_type_str in Motor class * chore(robots): no default value for can port in open arms * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests * remove comment * Add openarms docs * format * update purchase link * can to none if nit availabl;e * add canfd option in bus * make handshake logic similar to lerobot-can * type hint * type check * add temp teleop test * remove script * mock class * ignore linter --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/_toctree.yml | 2 + docs/source/openarm.mdx | 258 +++++++++++++ pyproject.toml | 1 + src/lerobot/motors/damiao/damiao.py | 51 ++- 
src/lerobot/processor/hil_processor.py | 12 +- .../robots/openarm_follower/__init__.py | 20 + .../config_openarm_follower.py | 117 ++++++ .../openarm_follower/openarm_follower.py | 348 ++++++++++++++++++ src/lerobot/robots/utils.py | 4 + src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../teleoperators/openarm_leader/__init__.py | 20 + .../openarm_leader/config_openarm_leader.py | 70 ++++ .../openarm_leader/openarm_leader.py | 225 +++++++++++ src/lerobot/teleoperators/utils.py | 14 +- 18 files changed, 1129 insertions(+), 22 deletions(-) create mode 100644 docs/source/openarm.mdx create mode 100644 src/lerobot/robots/openarm_follower/__init__.py create mode 100644 src/lerobot/robots/openarm_follower/config_openarm_follower.py create mode 100644 src/lerobot/robots/openarm_follower/openarm_follower.py create mode 100644 src/lerobot/teleoperators/openarm_leader/__init__.py create mode 100644 src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py create mode 100644 src/lerobot/teleoperators/openarm_leader/openarm_leader.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f86dd11c7..eb97117af 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -101,6 +101,8 @@ title: Earth Rover Mini - local: omx title: OMX + - local: openarm + title: OpenArm title: "Robots" - sections: - local: phone_teleop diff --git a/docs/source/openarm.mdx b/docs/source/openarm.mdx new file mode 100644 index 000000000..661808749 --- /dev/null +++ b/docs/source/openarm.mdx @@ -0,0 +1,258 @@ +# OpenArm + +[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment. 
+ +To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev). + +## What's Unique? + +- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation. + +- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks. + +- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use. + +- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided. + +- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs. + +## Platform Requirements + + + **Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter + does not have macOS drivers and has not been tested on Windows. + + +## Safety Guide + +Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). 
Key points: + +- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps +- **Safe distance**: Keep body parts and objects outside the range of motion during operation +- **Protective equipment**: Always wear safety goggles; use additional PPE as needed +- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm) +- **Emergency stop**: Know the location and operation of the emergency stop device +- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage + +## Hardware Setup + +Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for: + +- Bill of materials and sourcing +- 3D printing instructions +- Mechanical assembly +- Electrical wiring + +The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm). + +## CAN Bus Setup + +OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface. 
+ +Quick setup: + +```bash +# Setup CAN interfaces +lerobot-setup-can --mode=setup --interfaces=can0,can1 + +# Test motor communication +lerobot-setup-can --mode=test --interfaces=can0,can1 +``` + +## Install LeRobot 🤗 + +Follow our [Installation Guide](./installation), then install the Damiao motor support: + +```bash +pip install -e ".[damiao]" +``` + +## Usage + +### Follower Arm (Robot) + + + + +```bash +lerobot-calibrate \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_openarm_follower +``` + + + + +```python +from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig + +config = OpenArmFollowerConfig( + port="can0", + side="right", # or "left" for left arm + id="my_openarm_follower", +) + +follower = OpenArmFollower(config) +follower.connect() + +# Read current state +obs = follower.get_observation() +print(obs) + +# Send action (position in degrees) +action = { + "joint_1.pos": 0.0, + "joint_2.pos": 0.0, + "joint_3.pos": 0.0, + "joint_4.pos": 45.0, + "joint_5.pos": 0.0, + "joint_6.pos": 0.0, + "joint_7.pos": 0.0, + "gripper.pos": 0.0, +} +follower.send_action(action) + +follower.disconnect() +``` + + + + +### Leader Arm (Teleoperator) + +The leader arm is used for teleoperation - manually moving it to control the follower arm. 
+ + + + +```bash +lerobot-calibrate \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_openarm_leader +``` + + + + +```python +from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig + +config = OpenArmLeaderConfig( + port="can1", + id="my_openarm_leader", + manual_control=True, # Disable torque for manual movement +) + +leader = OpenArmLeader(config) +leader.connect() + +# Read current position (as action to send to follower) +action = leader.get_action() +print(action) + +leader.disconnect() +``` + + + + +### Teleoperation + +To teleoperate OpenArm with leader-follower control: + +```bash +lerobot-teleoperate \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_follower \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_leader +``` + +### Recording Data + +To record a dataset during teleoperation: + +```bash +lerobot-record \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_follower \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_leader \ + --repo-id=my_hf_username/my_openarm_dataset \ + --fps=30 \ + --num-episodes=10 +``` + +## Configuration Options + +### Follower Configuration + +| Parameter | Default | Description | +| --------------------- | --------- | ---------------------------------------------------------- | +| `port` | - | CAN interface (e.g., `can0`) | +| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | +| `max_relative_target` | `None` | Safety limit for relative target positions | +| `position_kp` | Per-joint | Position control proportional gains | +| `position_kd` | Per-joint | Position control derivative gains | + +### Leader Configuration + 
+| Parameter | Default | Description | +| ------------------ | --------- | ----------------------------------- | +| `port` | - | CAN interface (e.g., `can1`) | +| `manual_control` | `True` | Disable torque for manual movement | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | + +## Motor Configuration + +OpenArm uses Damiao motors with the following default configuration: + +| Joint | Motor Type | Send ID | Recv ID | +| --------------------------- | ---------- | ------- | ------- | +| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 | +| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 | +| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 | +| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 | +| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 | +| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 | +| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 | +| gripper | DM4310 | 0x08 | 0x18 | + +## Troubleshooting + +### No Response from Motors + +1. Check power supply connections +2. Verify CAN wiring (CAN-H, CAN-L, GND) +3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0` +4. 
See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details + +### CAN Interface Not Found + +Ensure the CAN interface is configured: + +```bash +ip link show can0 +``` + +## Resources + +- [OpenArm Website](https://openarm.dev) +- [OpenArm Documentation](https://docs.openarm.dev) +- [OpenArm GitHub](https://github.com/enactic/openarm) +- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide) +- [Damiao Motors and CAN Bus](./damiao) diff --git a/pyproject.toml b/pyproject.toml index 27126f855..ea2dfb4a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] damiao = ["python-can>=4.2.0,<5.0.0"] # Robots +openarms = ["lerobot[damiao]"] gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index dd0213fc3..c79f8d17e 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -28,8 +28,11 @@ from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: import can else: - can.Message = object - can.interface = None + + class can: # noqa: N801 + Message = object + interface = None + import numpy as np @@ -206,11 +209,31 @@ class DamiaoMotorsBus(MotorsBusBase): Raises ConnectionError if any motor fails to respond. 
""" logger.info("Starting handshake with motors...") - missing_motors = [] + # Drain any pending messages + while self.canbus.recv(timeout=0.01): + pass + + missing_motors = [] for motor_name in self.motors: - msg = self._refresh_motor(motor_name) - if msg is None: + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + + # Send enable command + data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + self.canbus.send(msg) + + # Wait for response with longer timeout + response = None + start_time = time.time() + while time.time() - start_time < 0.1: + response = self.canbus.recv(timeout=0.1) + if response and response.arbitration_id == recv_id: + break + response = None + + if response is None: missing_motors.append(motor_name) else: self._process_response(motor_name, msg) @@ -259,7 +282,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -317,7 +340,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_id = self._get_motor_id(motor) recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] - msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -439,7 +462,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_type = self._motor_types[motor_name] 
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) recv_id = self._get_motor_recv_id(motor) @@ -472,7 +495,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_type = self._motor_types[motor_name] data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name @@ -637,10 +660,10 @@ class DamiaoMotorsBus(MotorsBusBase): for motor in motors: motor_id = self._get_motor_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] - msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + msg = can.Message( + arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd + ) self.canbus.send(msg) - # Small delay to reduce bus congestion if necessary, though removed in sync_read previously - # precise_sleep(PRECISE_SLEEP_SEC) # Collect responses expected_recv_ids = [self._get_motor_recv_id(m) for m in motors] @@ -676,7 +699,9 @@ class DamiaoMotorsBus(MotorsBusBase): kd = self._gains[motor]["kd"] data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message( + arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd + ) self.canbus.send(msg) precise_sleep(PRECISE_TIMEOUT_SEC) diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index f0dbac9c3..6d44ed8cb 100644 --- 
a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -18,16 +18,18 @@ import math import time from dataclasses import dataclass -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +if TYPE_CHECKING: + from lerobot.teleoperators.teleoperator import Teleoperator + from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ( ComplementaryDataProcessorStep, @@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol): # Type variable constrained to Teleoperator subclasses that also implement events -TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator) +TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator") -def _check_teleop_with_events(teleop: Teleoperator) -> None: +def _check_teleop_with_events(teleop: "Teleoperator") -> None: """ Runtime check that a teleoperator implements the `HasTeleopEvents` protocol. @@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): teleop_device: The teleoperator instance to get the action from. """ - teleop_device: Teleoperator + teleop_device: "Teleoperator" def complementary_data(self, complementary_data: dict) -> dict: """ diff --git a/src/lerobot/robots/openarm_follower/__init__.py b/src/lerobot/robots/openarm_follower/__init__.py new file mode 100644 index 000000000..1eb0d9fc7 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_openarm_follower import OpenArmFollowerConfig +from .openarm_follower import OpenArmFollower + +__all__ = ["OpenArmFollower", "OpenArmFollowerConfig"] diff --git a/src/lerobot/robots/openarm_follower/config_openarm_follower.py b/src/lerobot/robots/openarm_follower/config_openarm_follower.py new file mode 100644 index 000000000..af95b6395 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/config_openarm_follower.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig + +from ..config import RobotConfig + +LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { + "joint_1": (-75.0, 75.0), + "joint_2": (-90.0, 9.0), + "joint_3": (-85.0, 85.0), + "joint_4": (0.0, 135.0), + "joint_5": (-85.0, 85.0), + "joint_6": (-40.0, 40.0), + "joint_7": (-80.0, 80.0), + "gripper": (-65.0, 0.0), +} + +RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { + "joint_1": (-75.0, 75.0), + "joint_2": (-9.0, 90.0), + "joint_3": (-85.0, 85.0), + "joint_4": (0.0, 135.0), + "joint_5": (-85.0, 85.0), + "joint_6": (-40.0, 40.0), + "joint_7": (-80.0, 80.0), + "gripper": (-65.0, 0.0), +} + + +@RobotConfig.register_subclass("openarm_follower") +@dataclass +class OpenArmFollowerConfig(RobotConfig): + """Configuration for the OpenArms follower robot with Damiao motors.""" + + # CAN interfaces - one per arm + # arm CAN interface (e.g., "can1") + # Linux: "can0", "can1", etc. + port: str + + # side of the arm: "left" or "right". 
If "None" default values will be used + side: str | None = None + + # CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect) + can_interface: str = "socketcan" + + # CAN FD settings (OpenArms uses CAN FD by default) + use_can_fd: bool = True + can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps) + can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps) + + # Whether to disable torque when disconnecting + disable_torque_on_disconnect: bool = True + + # Safety limit for relative target positions + # Set to a positive scalar for all motors, or a dict mapping motor names to limits + max_relative_target: float | dict[str, float] | None = None + + # Camera configurations + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Motor configuration for OpenArms (7 DOF per arm) + # Maps motor names to (send_can_id, recv_can_id, motor_type) + # Based on: https://docs.openarm.dev/software/setup/configure-test + # OpenArms uses 4 types of motors: + # - DM8009 (DM-J8009P-2EC) for shoulders (high torque) + # - DM4340P and DM4340 for shoulder rotation and elbow + # - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper + motor_config: dict[str, tuple[int, int, str]] = field( + default_factory=lambda: { + "joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009) + "joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009) + "joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340) + "joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340) + "joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310) + "joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310) + "joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310) + "gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310) + } + ) + + # MIT control parameters for position control (used in send_action) + # List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper] + position_kp: list[float] = field( + 
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0] + ) + position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3]) + + # Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'. + # If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety. + joint_limits: dict[str, tuple[float, float]] = field( + default_factory=lambda: { + "joint_1": (-5.0, 5.0), + "joint_2": (-5.0, 5.0), + "joint_3": (-5.0, 5.0), + "joint_4": (0.0, 5.0), + "joint_5": (-5.0, 5.0), + "joint_6": (-5.0, 5.0), + "joint_7": (-5.0, 5.0), + "gripper": (-5.0, 0.0), + } + ) diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py new file mode 100644 index 000000000..c221afd10 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.damiao import DamiaoMotorsBus +from lerobot.processor import RobotAction, RobotObservation +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_openarm_follower import ( + LEFT_DEFAULT_JOINTS_LIMITS, + RIGHT_DEFAULT_JOINTS_LIMITS, + OpenArmFollowerConfig, +) + +logger = logging.getLogger(__name__) + + +class OpenArmFollower(Robot): + """ + OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper. + The arm uses Damiao motors in MIT control mode. + """ + + config_class = OpenArmFollowerConfig + name = "openarm_follower" + + def __init__(self, config: OpenArmFollowerConfig): + super().__init__(config) + self.config = config + + # Arm motors + motors: dict[str, Motor] = {} + for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items(): + motor = Motor( + send_id, motor_type_str, MotorNormMode.DEGREES + ) # Always use degrees for Damiao motors + motor.recv_id = recv_id + motor.motor_type_str = motor_type_str + motors[motor_name] = motor + + self.bus = DamiaoMotorsBus( + port=self.config.port, + motors=motors, + calibration=self.calibration, + can_interface=self.config.can_interface, + use_can_fd=self.config.use_can_fd, + bitrate=self.config.can_bitrate, + data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None, + ) + + if config.side is not None: + if config.side == "left": + config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS + elif config.side == "right": + config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS + else: + raise ValueError( + "config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)" + ) 
+ else: + logger.info( + "Set config.side to either 'left' or 'right' to use pre-configured values for joint limits." + ) + logger.info(f"Values used for joint limits: {config.joint_limits}.") + + # Initialize cameras + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + """Motor features for observation and action spaces.""" + features: dict[str, type] = {} + for motor in self.bus.motors: + features[f"{motor}.pos"] = float + features[f"{motor}.vel"] = float # Add this + features[f"{motor}.torque"] = float # Add this + return features + + @property + def _cameras_ft(self) -> dict[str, tuple]: + """Camera features for observation space.""" + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + """Combined observation features from motors and cameras.""" + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + """Action features.""" + return self._motors_ft + + @property + def is_connected(self) -> bool: + """Check if robot is connected.""" + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + Connect to the robot and optionally calibrate. + + We assume that at connection time, the arms are in a safe rest position, + and torque can be safely disabled to run calibration if needed. 
+ """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + # Connect to CAN bus + logger.info(f"Connecting arm on {self.config.port}...") + self.bus.connect() + + # Run calibration if needed + if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + + if self.is_calibrated: + self.bus.set_zero_position() + + self.bus.enable_torque() + + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + """Check if robot is calibrated.""" + return self.bus.is_calibrated + + def calibrate(self) -> None: + """ + Run calibration procedure for OpenArms robot. + + The calibration procedure: + 1. Disable torque + 2. Ask user to position arms in hanging position with grippers closed + 3. Set this as zero position + 4. Record range of motion for each joint + 5. Save calibration + """ + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + + logger.info(f"\nRunning calibration for {self}") + self.bus.disable_torque() + + # Step 1: Set zero position + input( + "\nCalibration: Set Zero Position)\n" + "Position the arm in the following configuration:\n" + " - Arm hanging straight down\n" + " - Gripper closed\n" + "Press ENTER when ready..." 
+ ) + + # Set current position as zero for all motors + self.bus.set_zero_position() + logger.info("Arm zero position set.") + + logger.info("Setting range: -90° to +90° for safety by default for all joints") + for motor_name, motor in self.bus.motors.items(): + self.calibration[motor_name] = MotorCalibration( + id=motor.id, + drive_mode=0, + homing_offset=0, + range_min=-90, + range_max=90, + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + """Configure motors with appropriate settings.""" + # TODO(Steven, Pepijn): Slightly different from what it is happening in the leader + with self.bus.torque_disabled(): + self.bus.configure_motors() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_observation(self) -> RobotObservation: + """ + Get current observation from robot including position, velocity, and torque. + + Reads all motor states (pos/vel/torque) in one CAN refresh cycle + instead of 3 separate reads. 
+ """ + start = time.perf_counter() + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + obs_dict: dict[str, Any] = {} + + states = self.bus.sync_read_all_states() + + for motor in self.bus.motors: + state = states.get(motor, {}) + obs_dict[f"{motor}.pos"] = state.get("position", 0.0) + obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0) + obs_dict[f"{motor}.torque"] = state.get("torque", 0.0) + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms") + + return obs_dict + + def send_action( + self, + action: RobotAction, + custom_kp: dict[str, float] | None = None, + custom_kd: dict[str, float] | None = None, + ) -> RobotAction: + """ + Send action command to robot. + + The action magnitude may be clipped based on safety limits. 
+ + Args: + action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos") + custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0}) + custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0}) + + Returns: + The action actually sent (potentially clipped) + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Apply joint limit clipping to arm + for motor_name, position in goal_pos.items(): + if motor_name in self.config.joint_limits: + min_limit, max_limit = self.config.joint_limits[motor_name] + clipped_position = max(min_limit, min(max_limit, position)) + if clipped_position != position: + logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°") + goal_pos[motor_name] = clipped_position + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. 
+ if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # TODO(Steven, Pepijn): Refactor writing + # Motor name to index mapping for gains + motor_index = { + "joint_1": 0, + "joint_2": 1, + "joint_3": 2, + "joint_4": 3, + "joint_5": 4, + "joint_6": 5, + "joint_7": 6, + "gripper": 7, + } + + # Use batch MIT control for arm (sends all commands, then collects responses) + commands = {} + for motor_name, position_degrees in goal_pos.items(): + idx = motor_index.get(motor_name, 0) + # Use custom gains if provided, otherwise use config defaults + if custom_kp is not None and motor_name in custom_kp: + kp = custom_kp[motor_name] + else: + kp = ( + self.config.position_kp[idx] + if isinstance(self.config.position_kp, list) + else self.config.position_kp + ) + if custom_kd is not None and motor_name in custom_kd: + kd = custom_kd[motor_name] + else: + kd = ( + self.config.position_kd[idx] + if isinstance(self.config.position_kd, list) + else self.config.position_kd + ) + commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0) + + self.bus._mit_control_batch(commands) + + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + """Disconnect from robot.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Disconnect CAN bus + self.bus.disconnect(self.config.disable_torque_on_disconnect) + + # Disconnect cameras + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 27abaaa86..e0c76cab3 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -60,6 +60,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .reachy2 
import Reachy2Robot return Reachy2Robot(config) + elif config.type == "openarm_follower": + from .openarm_follower import OpenArmFollower + + return OpenArmFollower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index cbc7684d3..0f79e6aa2 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -42,6 +42,7 @@ from lerobot.robots import ( # noqa: F401 lekiwi, make_robot_from_config, omx_follower, + openarm_follower, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -52,6 +53,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, so_leader, ) from lerobot.utils.import_utils import register_third_party_plugins diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index 20bbc8615..d928dc5cd 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -48,6 +48,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -57,6 +58,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, so_leader, ) from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index f03776989..4d334f38f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -104,6 +104,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, unitree_g1, @@ -116,6 +117,7 @@ from lerobot.teleoperators import ( # noqa: F401 
koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, reachy2_teleoperator, so_leader, ) diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 49c06d643..c3bc3d766 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -59,6 +59,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, unitree_g1, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 18d8863d6..a415dd600 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -76,6 +76,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, ) @@ -89,6 +90,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, reachy2_teleoperator, so_leader, ) diff --git a/src/lerobot/teleoperators/openarm_leader/__init__.py b/src/lerobot/teleoperators/openarm_leader/__init__.py new file mode 100644 index 000000000..1493317fe --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .config_openarm_leader import OpenArmLeaderConfig +from .openarm_leader import OpenArmLeader + +__all__ = ["OpenArmLeader", "OpenArmLeaderConfig"] diff --git a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py new file mode 100644 index 000000000..c53169b0a --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("openarm_leader") +@dataclass +class OpenArmLeaderConfig(TeleoperatorConfig): + """Configuration for the OpenArms leader/teleoperator with Damiao motors.""" + + # CAN interfaces - one per arm + # Arm CAN interface (e.g., "can3") + # Linux: "can0", "can1", etc. 
+ port: str + + # CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect) + can_interface: str = "socketcan" + + # CAN FD settings (OpenArms uses CAN FD by default) + use_can_fd: bool = True + can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps) + can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps) + + # Motor configuration for OpenArms (7 DOF per arm) + # Maps motor names to (send_can_id, recv_can_id, motor_type) + # Based on: https://docs.openarm.dev/software/setup/configure-test + # OpenArms uses 4 types of motors: + # - DM8009 (DM-J8009P-2EC) for shoulders (high torque) + # - DM4340P and DM4340 for shoulder rotation and elbow + # - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper + motor_config: dict[str, tuple[int, int, str]] = field( + default_factory=lambda: { + "joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009) + "joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009) + "joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340) + "joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340) + "joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310) + "joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310) + "joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310) + "gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310) + } + ) + + # Torque mode settings for manual control + # When enabled, motors have torque disabled for manual movement + manual_control: bool = True + + # TODO(Steven, Pepijn): Not used ... ? 
+ # MIT control parameters (used when manual_control=False for torque control) + # List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper] + position_kp: list[float] = field( + default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0] + ) + position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2]) diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py new file mode 100644 index 000000000..edf4d7090 --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.damiao import DamiaoMotorsBus +from lerobot.processor import RobotAction +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..teleoperator import Teleoperator +from .config_openarm_leader import OpenArmLeaderConfig + +logger = logging.getLogger(__name__) + + +class OpenArmLeader(Teleoperator): + """ + OpenArm Leader/Teleoperator Arm with Damiao motors. + + This teleoperator uses CAN bus communication to read positions from + Damiao motors that are manually moved (torque disabled). 
+ """ + + config_class = OpenArmLeaderConfig + name = "openarm_leader" + + def __init__(self, config: OpenArmLeaderConfig): + super().__init__(config) + self.config = config + + # Arm motors + motors: dict[str, Motor] = {} + for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items(): + motor = Motor( + send_id, motor_type_str, MotorNormMode.DEGREES + ) # Always use degrees for Damiao motors + motor.recv_id = recv_id + motor.motor_type_str = motor_type_str + motors[motor_name] = motor + + self.bus = DamiaoMotorsBus( + port=self.config.port, + motors=motors, + calibration=self.calibration, + can_interface=self.config.can_interface, + use_can_fd=self.config.use_can_fd, + bitrate=self.config.can_bitrate, + data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None, + ) + + @property + def action_features(self) -> dict[str, type]: + """Features produced by this teleoperator.""" + features: dict[str, type] = {} + for motor in self.bus.motors: + features[f"{motor}.pos"] = float + features[f"{motor}.vel"] = float + features[f"{motor}.torque"] = float + return features + + @property + def feedback_features(self) -> dict[str, type]: + """Feedback features (not implemented for OpenArms).""" + return {} + + @property + def is_connected(self) -> bool: + """Check if teleoperator is connected.""" + return self.bus.is_connected + + def connect(self, calibrate: bool = True) -> None: + """ + Connect to the teleoperator. + + For manual control, we disable torque after connecting so the + arm can be moved by hand. 
+ """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + # Connect to CAN bus + logger.info(f"Connecting arm on {self.config.port}...") + self.bus.connect() + + # Run calibration if needed + if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) + self.calibrate() + + self.configure() + + if self.is_calibrated: + self.bus.set_zero_position() + + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + """Check if teleoperator is calibrated.""" + return self.bus.is_calibrated + + def calibrate(self) -> None: + """ + Run calibration procedure for OpenArms leader. + + The calibration procedure: + 1. Disable torque (if not already disabled) + 2. Ask user to position arm in zero position (hanging with gripper closed) + 3. Set this as zero position + 4. Record range of motion for each joint + 5. Save calibration + """ + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + + logger.info(f"\nRunning calibration for {self}") + self.bus.disable_torque() + + # Step 1: Set zero position + input( + "\nCalibration: Set Zero Position)\n" + "Position the arm in the following configuration:\n" + " - Arm hanging straight down\n" + " - Gripper closed\n" + "Press ENTER when ready..." 
+ ) + + # Set current position as zero for all motors + self.bus.set_zero_position() + logger.info("Arm zero position set.") + + logger.info("Setting range: -90° to +90° by default for all joints") + # TODO(Steven, Pepijn): Check if MotorCalibration is actually needed here given that we only use Degrees + for motor_name, motor in self.bus.motors.items(): + self.calibration[motor_name] = MotorCalibration( + id=motor.id, + drive_mode=0, + homing_offset=0, + range_min=-90, + range_max=90, + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + """ + Configure motors for manual teleoperation. + + For manual control, we disable torque so the arm can be moved by hand. + """ + + return self.bus.disable_torque() if self.config.manual_control else self.bus.configure_motors() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_action(self) -> RobotAction: + """ + Get current action from the leader arm. + + This is the main method for teleoperators - it reads the current state + of the leader arm and returns it as an action that can be sent to a follower. + + Reads all motor states (pos/vel/torque) in one CAN refresh cycle. 
+ """ + start = time.perf_counter() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + action_dict: dict[str, Any] = {} + + # Use sync_read_all_states to get pos/vel/torque in one go + states = self.bus.sync_read_all_states() + for motor in self.bus.motors: + state = states.get(motor, {}) + action_dict[f"{motor}.pos"] = state.get("position") + action_dict[f"{motor}.vel"] = state.get("velocity") + action_dict[f"{motor}.torque"] = state.get("torque") + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + return action_dict + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.") + + def disconnect(self) -> None: + """Disconnect from teleoperator.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Disconnect CAN bus + # For manual control, ensure torque is disabled before disconnecting + self.bus.disconnect(disable_torque=self.config.manual_control) + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index eec2f119c..8f6bbc787 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -13,12 +13,14 @@ # limitations under the License. 
from enum import Enum -from typing import cast +from typing import TYPE_CHECKING, cast from lerobot.utils.import_utils import make_device_from_device_class from .config import TeleoperatorConfig -from .teleoperator import Teleoperator + +if TYPE_CHECKING: + from .teleoperator import Teleoperator class TeleopEvents(Enum): @@ -31,7 +33,7 @@ class TeleopEvents(Enum): TERMINATE_EPISODE = "terminate_episode" -def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: +def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": # TODO(Steven): Consider just using the make_device_from_device_class for all types if config.type == "keyboard": from .keyboard import KeyboardTeleop @@ -81,8 +83,12 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .reachy2_teleoperator import Reachy2Teleoperator return Reachy2Teleoperator(config) + elif config.type == "openarm_leader": + from .openarm_leader import OpenArmLeader + + return OpenArmLeader(config) else: try: - return cast(Teleoperator, make_device_from_device_class(config)) + return cast("Teleoperator", make_device_from_device_class(config)) except Exception as e: raise ValueError(f"Error creating robot with config {config}: {e}") from e From 149628dfd5b3079a2c0f80832deac1da3e7bd287 Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:17:38 +0100 Subject: [PATCH 07/43] add g1 teleoperation (#2791) * add gravity compensation * add g1 teleoperation --------- Co-authored-by: Michel Aractingi --- docs/source/unitree_g1.mdx | 100 +++- pyproject.toml | 6 +- .../robots/unitree_g1/config_unitree_g1.py | 3 + src/lerobot/robots/unitree_g1/g1_utils.py | 2 +- .../unitree_g1/robot_kinematic_processor.py | 313 ++++++++++++ src/lerobot/robots/unitree_g1/unitree_g1.py | 19 +- src/lerobot/scripts/lerobot_calibrate.py | 1 + src/lerobot/scripts/lerobot_record.py | 3 +- src/lerobot/scripts/lerobot_teleoperate.py 
| 2 + .../teleoperators/unitree_g1/__init__.py | 21 + .../unitree_g1/config_unitree_g1.py | 37 ++ .../teleoperators/unitree_g1/exo_calib.py | 446 ++++++++++++++++++ .../teleoperators/unitree_g1/exo_ik.py | 353 ++++++++++++++ .../teleoperators/unitree_g1/exo_serial.py | 119 +++++ .../teleoperators/unitree_g1/unitree_g1.py | 157 ++++++ src/lerobot/teleoperators/utils.py | 4 + 16 files changed, 1581 insertions(+), 5 deletions(-) create mode 100644 src/lerobot/robots/unitree_g1/robot_kinematic_processor.py create mode 100644 src/lerobot/teleoperators/unitree_g1/__init__.py create mode 100644 src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_calib.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_ik.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_serial.py create mode 100644 src/lerobot/teleoperators/unitree_g1/unitree_g1.py diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index e6bffdf1b..ea6bf54ad 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy. ## Running in Simulation Mode (MuJoCo) -You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. +You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI. 
+ +### Calibrate Exoskeleton Teleoperator + +```bash +lerobot-calibrate \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo +``` + +### Teleoperate in Simulation + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --fps=100 +``` + +### Record Dataset in Simulation + +```bash +python -m lerobot.scripts.lerobot_record \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true +``` + +Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim) + +--- + +## Running on Real Robot + +Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot. + +### Start the Camera Server + +On the robot, start the ZMQ image server: + +```bash +python src/lerobot/cameras/zmq/image_server.py +``` + +Keep this running in a separate terminal for camera streaming during recording. 
+ +### Teleoperate Real Robot + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --fps=100 +``` + +### Record Dataset on Real Robot + +```bash +python -m lerobot.scripts.lerobot_record \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true +``` + +**Note**: Update `server_address` to match your robot's camera server IP. 
+ +Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real) + +--- ## Additional Resources diff --git a/pyproject.toml b/pyproject.toml index ea2dfb4a2..210d70b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,11 @@ hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] unitree_g1 = [ "pyzmq>=26.2.1,<28.0.0", - "onnxruntime>=1.16.0,<2.0.0" + "onnxruntime>=1.16.0,<2.0.0", + "pin>=3.0.0,<4.0.0", + "meshcat>=0.3.0,<0.4.0", + "matplotlib>=3.9.0,<4.0.0", + "casadi>=3.6.0,<4.0.0", ] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] kinematics = ["lerobot[placo-dep]"] diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index 0b163019d..1b81214a6 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -65,3 +65,6 @@ class UnitreeG1Config(RobotConfig): # Cameras (ZMQ-based remote cameras) cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Compensates for gravity on the unitree's arms using the arm ik solver + gravity_compensation: bool = False diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index 3c41ee985..4e37bdcef 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -18,7 +18,7 @@ from enum import IntEnum # ruff: noqa: N801, N815 -NUM_MOTORS = 35 +NUM_MOTORS = 29 class G1_29_JointArmIndex(IntEnum): diff --git a/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py b/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py new file mode 100644 index 000000000..d086a9986 --- /dev/null +++ b/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys + +import numpy as np + +logger = logging.getLogger(__name__) +parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(parent2_dir) + + +class WeightedMovingFilter: + def __init__(self, weights, data_size=14): + self._window_size = len(weights) + self._weights = np.array(weights) + self._data_size = data_size + self._filtered_data = np.zeros(self._data_size) + self._data_queue = [] + + def _apply_filter(self): + if len(self._data_queue) < self._window_size: + return self._data_queue[-1] + + data_array = np.array(self._data_queue) + temp_filtered_data = np.zeros(self._data_size) + for i in range(self._data_size): + temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1] + + return temp_filtered_data + + def add_data(self, new_data): + assert len(new_data) == self._data_size + + if len(self._data_queue) > 0 and np.array_equal( + new_data, self._data_queue[-1] + ): # skip duplicate data + return + + if len(self._data_queue) >= self._window_size: + self._data_queue.pop(0) + + self._data_queue.append(new_data) + self._filtered_data = self._apply_filter() + + @property + def filtered_data(self): + return self._filtered_data + + +class G1_29_ArmIK: # noqa: N801 + def __init__(self, unit_test=False): + import casadi + import pinocchio as pin + from huggingface_hub import snapshot_download + from pinocchio 
import casadi as cpin + + self._pin = pin + np.set_printoptions(precision=5, suppress=True, linewidth=200) + + self.unit_test = unit_test + + self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco") + urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf") + mesh_dir = os.path.join(self.repo_path, "assets") + + self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir) + + self.mixed_jointsToLockIDs = [ + "left_hip_pitch_joint", + "left_hip_roll_joint", + "left_hip_yaw_joint", + "left_knee_joint", + "left_ankle_pitch_joint", + "left_ankle_roll_joint", + "right_hip_pitch_joint", + "right_hip_roll_joint", + "right_hip_yaw_joint", + "right_knee_joint", + "right_ankle_pitch_joint", + "right_ankle_roll_joint", + "waist_yaw_joint", + "waist_roll_joint", + "waist_pitch_joint", + "left_hand_thumb_0_joint", + "left_hand_thumb_1_joint", + "left_hand_thumb_2_joint", + "left_hand_middle_0_joint", + "left_hand_middle_1_joint", + "left_hand_index_0_joint", + "left_hand_index_1_joint", + "right_hand_thumb_0_joint", + "right_hand_thumb_1_joint", + "right_hand_thumb_2_joint", + "right_hand_index_0_joint", + "right_hand_index_1_joint", + "right_hand_middle_0_joint", + "right_hand_middle_1_joint", + ] + + self.reduced_robot = self.robot.buildReducedRobot( + list_of_joints_to_lock=self.mixed_jointsToLockIDs, + reference_configuration=np.array([0.0] * self.robot.model.nq), + ) + + # Arm joint names in G1 motor order (G1_29_JointArmIndex) + self._arm_joint_names_g1 = [ + "left_shoulder_pitch_joint", + "left_shoulder_roll_joint", + "left_shoulder_yaw_joint", + "left_elbow_joint", + "left_wrist_roll_joint", + "left_wrist_pitch_joint", + "left_wrist_yaw_joint", + "right_shoulder_pitch_joint", + "right_shoulder_roll_joint", + "right_shoulder_yaw_joint", + "right_elbow_joint", + "right_wrist_roll_joint", + "right_wrist_pitch_joint", + "right_wrist_yaw_joint", + ] + # Pinocchio uses its own joint order in q; build index mapping. 
+ self._arm_joint_names_pin = sorted( + self._arm_joint_names_g1, + key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)], + ) + logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}") + self._arm_reorder_g1_to_pin = [ + self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin + ] + # Inverse mapping to return tau in G1 motor order. + self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin) + + self.reduced_robot.model.addFrame( + self._pin.Frame( + "L_ee", + self.reduced_robot.model.getJointId("left_wrist_yaw_joint"), + self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + self._pin.FrameType.OP_FRAME, + ) + ) + + self.reduced_robot.model.addFrame( + self._pin.Frame( + "R_ee", + self.reduced_robot.model.getJointId("right_wrist_yaw_joint"), + self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + self._pin.FrameType.OP_FRAME, + ) + ) + + # Creating Casadi models and data for symbolic computing + self.cmodel = cpin.Model(self.reduced_robot.model) + self.cdata = self.cmodel.createData() + + # Creating symbolic variables + self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1) + self.cTf_l = casadi.SX.sym("tf_l", 4, 4) + self.cTf_r = casadi.SX.sym("tf_r", 4, 4) + cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq) + + # Get the hand joint ID and define the error function + self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee") + self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee") + + self.translational_error = casadi.Function( + "translational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3], + self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3], + ) + ], + ) + self.rotational_error = casadi.Function( + "rotational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T), + 
cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T), + ) + ], + ) + + # Defining the optimization problem + self.opti = casadi.Opti() + self.var_q = self.opti.variable(self.reduced_robot.model.nq) + self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth + self.param_tf_l = self.opti.parameter(4, 4) + self.param_tf_r = self.opti.parameter(4, 4) + self.translational_cost = casadi.sumsqr( + self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r) + ) + self.rotation_cost = casadi.sumsqr( + self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r) + ) + self.regularization_cost = casadi.sumsqr(self.var_q) + self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last) + + # Setting optimization constraints and goals + self.opti.subject_to( + self.opti.bounded( + self.reduced_robot.model.lowerPositionLimit, + self.var_q, + self.reduced_robot.model.upperPositionLimit, + ) + ) + self.opti.minimize( + 50 * self.translational_cost + + self.rotation_cost + + 0.02 * self.regularization_cost + + 0.1 * self.smooth_cost + ) + + opts = { + "ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6}, + "print_time": False, # print or not + "calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F + } + self.opti.solver("ipopt", opts) + + self.init_data = np.zeros(self.reduced_robot.model.nq) + self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14) + + def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + if current_lr_arm_motor_q is not None: + self.init_data = current_lr_arm_motor_q + self.opti.set_initial(self.var_q, self.init_data) + + self.opti.set_value(self.param_tf_l, left_wrist) + self.opti.set_value(self.param_tf_r, right_wrist) + self.opti.set_value(self.var_q_last, self.init_data) # for smooth + + try: + self.opti.solve() + + sol_q = self.opti.value(self.var_q) + 
self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + if current_lr_arm_motor_dq is not None: + v = current_lr_arm_motor_dq * 0.0 + else: + v = (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + + return sol_q, sol_tauff + + except Exception as e: + logger.error(f"ERROR in convergence, plotting debug info.{e}") + + sol_q = self.opti.debug.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + if current_lr_arm_motor_dq is not None: + v = current_lr_arm_motor_dq * 0.0 + else: + v = (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + logger.error( + f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}" + ) + + return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv) + + def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + try: + q_g1 = np.array(current_lr_arm_motor_q, dtype=float) + if q_g1.shape[0] != len(self._arm_joint_names_g1): + raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}") + q_pin = q_g1[self._arm_reorder_g1_to_pin] + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + q_pin, + np.zeros(self.reduced_robot.model.nv), + np.zeros(self.reduced_robot.model.nv), + ) + return sol_tauff[self._arm_reorder_pin_to_g1] + + except Exception as e: + logger.error(f"ERROR in convergence, plotting debug info.{e}") + return np.zeros(self.reduced_robot.model.nv) diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index fa6e0da85..01b4f330e 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -27,7 +27,8 @@ import numpy as np from lerobot.cameras.utils 
import make_cameras_from_configs from lerobot.envs.factory import make_env from lerobot.processor import RobotAction, RobotObservation -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex +from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK from ..robot import Robot from .config_unitree_g1 import UnitreeG1Config @@ -127,6 +128,8 @@ class UnitreeG1(Robot): self.subscribe_thread = None self.remote_controller = self.RemoteController() + self.arm_ik = G1_29_ArmIK() + def _subscribe_motor_state(self): # polls robot state @ 250Hz while not self._shutdown_event.is_set(): start_time = time.time() @@ -361,6 +364,20 @@ class UnitreeG1(Robot): self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] self.msg.motor_cmd[motor.value].tau = 0 + if self.config.gravity_compensation: + # Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13) + action_np = np.zeros(14) + arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15 + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + action_np[local_idx] = self.msg.motor_cmd[joint.value].q + tau = self.arm_ik.solve_tau(action_np) + + # Apply tau back to motor commands + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + self.msg.motor_cmd[joint.value].tau = tau[local_idx] + self.msg.crc = self.crc.Crc(self.msg) self.lowcmd_publisher.Write(self.msg) return action diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 0f79e6aa2..2fa1b2a03 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -55,6 +55,7 @@ from lerobot.teleoperators import ( # noqa: F401 omx_leader, openarm_leader, so_leader, + unitree_g1, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.utils import init_logging diff --git 
a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 4d334f38f..d621189e8 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -107,7 +107,7 @@ from lerobot.robots import ( # noqa: F401 openarm_follower, reachy2, so_follower, - unitree_g1, + unitree_g1 as unitree_g1_robot, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -120,6 +120,7 @@ from lerobot.teleoperators import ( # noqa: F401 openarm_leader, reachy2_teleoperator, so_leader, + unitree_g1, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index a415dd600..958bd00ef 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -79,6 +79,7 @@ from lerobot.robots import ( # noqa: F401 openarm_follower, reachy2, so_follower, + unitree_g1 as unitree_g1_robot, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -93,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401 openarm_leader, reachy2_teleoperator, so_leader, + unitree_g1, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/teleoperators/unitree_g1/__init__.py b/src/lerobot/teleoperators/unitree_g1/__init__.py new file mode 100644 index 000000000..45955a0e2 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_unitree_g1 import ExoskeletonArmPortConfig, UnitreeG1TeleoperatorConfig +from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration +from .exo_ik import ExoskeletonIKHelper +from .exo_serial import ExoskeletonArm +from .unitree_g1 import UnitreeG1Teleoperator diff --git a/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py new file mode 100644 index 000000000..66c4e7f31 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field + +from ..config import TeleoperatorConfig + + +@dataclass +class ExoskeletonArmPortConfig: + """Serial port configuration for individual exoskeleton arm.""" + + port: str = "" + baud_rate: int = 115200 + + +@TeleoperatorConfig.register_subclass("unitree_g1") +@dataclass +class UnitreeG1TeleoperatorConfig(TeleoperatorConfig): + left_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig) + right_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig) + + # Frozen joints (comma-separated joint names that won't be moved by IK) + frozen_joints: str = "" diff --git a/src/lerobot/teleoperators/unitree_g1/exo_calib.py b/src/lerobot/teleoperators/unitree_g1/exo_calib.py new file mode 100644 index 000000000..2927a1b55 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_calib.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module handles calibration of hall effect sensors used in the exoskeleton. +Each joint has a pair of ADC channels outputting sin and cos values that trace an ellipse +as the joint rotates due to imprecision in magnet/sensor placement. We fit this ellipse to a unit circle, +and calculate arctan2 of the unit circle to get the joint angle. 
+We then store the ellipse parameters and the zero offset for each joint to be used at runtime. +""" + +import json +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import serial + +logger = logging.getLogger(__name__) + + +# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw +JOINTS = { + "shoulder_pitch": (0, 1), + "shoulder_yaw": (2, 3), + "shoulder_roll": (4, 5), + "elbow_flex": (6, 7), + "wrist_roll": (14, 15), +} + + +@dataclass +class ExoskeletonJointCalibration: + name: str # joint name + center_fit: list[float] # center of the ellipse + T: list[list[float]] # 2x2 transformation matrix + zero_offset: float = 0.0 # angle at neutral pose + + +@dataclass +class ExoskeletonCalibration: + """Full calibration data for an exoskeleton arm.""" + + version: int = 2 + side: str = "" + adc_max: int = 2**12 - 1 + joints: list[ExoskeletonJointCalibration] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "version": self.version, + "side": self.side, + "adc_max": self.adc_max, + "joints": [ + { + "name": j.name, + "center_fit": j.center_fit, + "T": j.T, + "zero_offset": j.zero_offset, + } + for j in self.joints + ], + } + + @classmethod + def from_dict(cls, data: dict) -> "ExoskeletonCalibration": + joints = [ + ExoskeletonJointCalibration( + name=j["name"], + center_fit=j["center_fit"], + T=j["T"], + zero_offset=j.get("zero_offset", 0.0), + ) + for j in data.get("joints", []) + ] + return cls( + version=data.get("version", 2), + side=data.get("side", ""), + adc_max=data.get("adc_max", 2**12 - 1), + joints=joints, + ) + + +@dataclass(frozen=True) +class CalibParams: + fit_every: float = 0.15 + min_fit_points: int = 60 + fit_window: int = 900 + max_fit_points: int = 300 + trim_low: float = 0.05 + trim_high: float = 0.95 + median_window: int = 5 + history: int = 3500 + draw_hz: float = 120.0 + sample_count: int = 50 + + 
+def normalize_angle(angle: float) -> float:
+    while angle > np.pi:
+        angle -= 2 * np.pi
+    while angle < -np.pi:
+        angle += 2 * np.pi
+    return angle
+
+
+def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
+    """
+    Applies calibration to each joint: raw → centered → ellipse-to-circle → angle.
+    """
+    pair = JOINTS[j.name]
+    s, c = raw16[pair[0]], raw16[pair[1]]  # get sin and cos
+    p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2])  # center the raw values
+    z = np.asarray(j.T) @ (
+        p - np.asarray(j.center_fit)
+    )  # center the ellipse and invert the transformation matrix to get unit circle coords
+    ang = float(np.arctan2(z[1], z[0])) - j.zero_offset  # calculate the angle and apply the zero offset
+    return z, normalize_angle(-ang)  # ensure range is [-pi, pi]
+
+
+def exo_raw_to_angles(raw16: list[int], calib: ExoskeletonCalibration) -> dict[str, float]:
+    """Convert raw sensor readings to joint angles using calibration."""
+    return {j.name: joint_z_and_angle(raw16, j)[1] for j in calib.joints}
+
+
+def run_exo_calibration(
+    ser: serial.Serial,
+    side: str,
+    save_path: Path,
+    params: CalibParams | None = None,
+) -> ExoskeletonCalibration:
+    """
+    Run interactive calibration for an exoskeleton arm.
+    """
+    try:
+        import cv2
+        import matplotlib.pyplot as plt
+    except ImportError as e:
+        raise ImportError(
+            "Calibration requires matplotlib and opencv-python. 
" + "Install with: pip install matplotlib opencv-python" + ) from e + + from .exo_serial import read_raw_from_serial + + params = params or CalibParams() + joint_list = list(JOINTS.items()) # Convert dict to list for indexing + logger.info(f"Starting calibration for {side} exoskeleton arm") + + def running_median(win: deque) -> float: + return float(np.median(np.fromiter(win, dtype=float))) + + def read_joint_point(raw16: list[int], pair: tuple[int, int]): + s, c = raw16[pair[0]], raw16[pair[1]] + return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c) + + def select_fit_subset(xs, ys): + """Select and filter points for ellipse fitting. Trims outliers by radius and downsamples.""" + n = min(params.fit_window, len(xs)) + if n <= 0: + return None, None + x = np.asarray(list(xs)[-n:], dtype=float) # most recent n samples + y = np.asarray(list(ys)[-n:], dtype=float) + r = np.sqrt(x * x + y * y) # radius from origin + if len(r) >= 20: + lo, hi = np.quantile(r, params.trim_low), np.quantile(r, params.trim_high) # outlier bounds + keep = (r >= lo) & (r <= hi) + x, y = x[keep], y[keep] # remove outliers + if len(x) > params.max_fit_points: + idx = np.linspace(0, len(x) - 1, params.max_fit_points).astype(int) # downsample evenly + x, y = x[idx], y[idx] + return x, y + + def fit_ellipse_opencv(x, y): + """Fit ellipse to (x,y) points using OpenCV. 
Returns center, axes, rotation matrix, and outline.""" + x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float) + if len(x) < 5: + return None + pts = np.stack([x, y], axis=1).astype(np.float32).reshape(-1, 1, 2) + try: + (xc, yc), (w, h), angle_deg = cv2.fitEllipse(pts) # returns center, axes, rotation in degrees + except cv2.error: + return None + a, b = float(w) * 0.5, float(h) * 0.5 # get ellipse major and minor semi-axes + phi = np.deg2rad(float(angle_deg)) # to rad + if b > a: # ensure major axis is a + a, b = b, a + phi += np.pi / 2.0 + if not np.isfinite(a) or not np.isfinite(b) or a <= 1e-6 or b <= 1e-6: + return None + cp, sp = float(np.cos(phi)), float(np.sin(phi)) # + rot = np.array([[cp, -sp], [sp, cp]], dtype=float) # 2x2 rotation matrix + center = np.array([float(xc), float(yc)], dtype=float) # offset vector + tt = np.linspace(0, 2 * np.pi, 360) + outline = (rot @ np.stack([a * np.cos(tt), b * np.sin(tt)])).T + center # for viz + return {"center": center, "a": a, "b": b, "R": rot, "ex": outline[:, 0], "ey": outline[:, 1]} + + # Setup matplotlib + plt.ion() + fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6)) + ax0.set_xlabel("cos - center") + ax0.set_ylabel("sin - center") + ax0.grid(True, alpha=0.25) + ax0.set_aspect("equal", adjustable="box") + ax1.set_title("Unit circle + angle") + ax1.set_xlabel("x") + ax1.set_ylabel("y") + ax1.grid(True, alpha=0.25) + ax1.set_aspect("equal", adjustable="box") + tt = np.linspace(0, 2 * np.pi, 360) + ax1.plot(np.cos(tt), np.sin(tt), "k-", linewidth=1) + ax0.set_xlim(-2200, 2200) + ax0.set_ylim(-2200, 2200) + ax1.set_xlim(-1.4, 1.4) + ax1.set_ylim(-1.4, 1.4) + + sc0 = ax0.scatter([], [], s=6, animated=True) + (ell_line,) = ax0.plot([], [], "r-", linewidth=2, animated=True) + sc1 = ax1.scatter([], [], s=6, animated=True) + (radius_line,) = ax1.plot([], [], "g-", linewidth=2, animated=True) + angle_text = ax1.text( + 0.02, 0.98, "", transform=ax1.transAxes, va="top", ha="left", fontsize=12, animated=True + 
) + + fig.canvas.draw() + bg0 = fig.canvas.copy_from_bbox(ax0.bbox) + bg1 = fig.canvas.copy_from_bbox(ax1.bbox) + + # State + joints_out = [] + joint_idx = 0 + phase = "ellipse" + advance_requested = False + zero_samples = [] + + def on_key(event): + nonlocal advance_requested + if event.key in ("n", "N", "enter", " "): + advance_requested = True + + fig.canvas.mpl_connect("key_press_event", on_key) + + def reset_state(): + return { + "xs": deque(maxlen=params.history), + "ys": deque(maxlen=params.history), + "xu": deque(maxlen=params.history), + "yu": deque(maxlen=params.history), + "win_s": deque(maxlen=params.median_window), + "win_c": deque(maxlen=params.median_window), + "ellipse_cache": None, + "T": None, + "center_fit": None, + "have_transform": False, + "latest_z": None, + "last_fit": 0.0, + } + + state = reset_state() + last_draw = 0.0 + name, pair = joint_list[joint_idx] + fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE") + ax0.set_title(f"{name} raw (filtered)") + logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}") + logger.info("Step 1: Move joint around to map ellipse, then press 'n'") + + try: + while plt.fignum_exists(fig.number): + name, pair = joint_list[joint_idx] + + # Handles calibration GUI state: ellipse → zero_pose → next joint -> ellipse -> ... 
+ if phase == "ellipse" and advance_requested and state["have_transform"]: + joints_out.append( + { + "name": name, + "center_fit": state["center_fit"].tolist(), + "T": state["T"].tolist(), + } + ) + logger.info(f" -> Ellipse saved for {name}") + phase, zero_samples, advance_requested = "zero_pose", [], False + fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ZERO POSE") + ax0.set_title(f"{name} - hold zero pose") + fig.canvas.draw() + bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox) + logger.info(f"Step 2: Hold {name} in zero position, then press 'n'") + + elif phase == "ellipse" and advance_requested and not state["have_transform"]: + logger.info(" (Need valid fit first - keep moving the joint)") + advance_requested = False + + elif phase == "zero_pose" and advance_requested: + if len(zero_samples) >= params.sample_count: + zero_offset = float(np.mean(zero_samples[-params.sample_count :])) + joints_out[-1]["zero_offset"] = zero_offset + logger.info(f" -> {name} zero: {zero_offset:+.3f} rad ({np.degrees(zero_offset):+.1f}°)") + joint_idx += 1 + advance_requested = False + + if joint_idx >= len(joint_list): + # All joints done + calib = ExoskeletonCalibration( + version=2, + side=side, + adc_max=2**12 - 1, + joints=[ + ExoskeletonJointCalibration( + name=j["name"], + center_fit=j["center_fit"], + T=j["T"], + zero_offset=j.get("zero_offset", 0.0), + ) + for j in joints_out + ], + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + with open(save_path, "w") as f: + json.dump(calib.to_dict(), f, indent=2) + logger.info(f"Saved calibration to {save_path}") + logger.info("Calibration complete!") + plt.close(fig) + return calib + + # Next joint + phase, state = "ellipse", reset_state() + name, pair = joint_list[joint_idx] + fig.canvas.manager.set_window_title( + f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE" + ) + ax0.set_title(f"{name} raw (filtered)") + fig.canvas.draw() + bg0, bg1 = 
fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox) + logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}") + logger.info("Step 1: Move joint around to map ellipse, then press 'n'") + else: + logger.info( + f" (Collecting samples: {len(zero_samples)}/{params.sample_count} - hold still)" + ) + advance_requested = False + + # Read sensor + raw16 = read_raw_from_serial(ser) + if raw16 is not None: + x_raw, y_raw, s_raw, c_raw = read_joint_point(raw16, pair) + + if phase == "ellipse": + if state["have_transform"]: + z = state["T"] @ (np.array([x_raw, y_raw]) - state["center_fit"]) + state["xu"].append(float(z[0])) + state["yu"].append(float(z[1])) + state["latest_z"] = (float(z[0]), float(z[1])) + state["win_s"].append(s_raw) + state["win_c"].append(c_raw) + if len(state["win_s"]) >= max(3, params.median_window): + state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2) + state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2) + else: + jdata = joints_out[-1] + z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"])) + zero_samples.append(float(np.arctan2(z[1], z[0]))) + state["latest_z"] = (float(z[0]), float(z[1])) + + # Ellipse fitting + t = time.time() + if ( + phase == "ellipse" + and (t - state["last_fit"]) >= params.fit_every + and len(state["xs"]) >= params.min_fit_points + ): + xfit, yfit = select_fit_subset(state["xs"], state["ys"]) + if xfit is not None and len(xfit) >= params.min_fit_points: + fit = fit_ellipse_opencv(xfit, yfit) + if fit is not None: + state["center_fit"] = fit["center"] + state["T"] = np.diag([1.0 / fit["a"], 1.0 / fit["b"]]) @ fit["R"].T + state["ellipse_cache"] = (fit["ex"], fit["ey"]) + state["have_transform"] = True + state["last_fit"] = t + + # Drawing + if (t - last_draw) >= 1.0 / params.draw_hz: + fig.canvas.restore_region(bg0) + fig.canvas.restore_region(bg1) + + if phase == "ellipse": + sc0.set_offsets(np.c_[state["xs"], state["ys"]] if 
state["xs"] else np.empty((0, 2))) + ax0.draw_artist(sc0) + ell_line.set_data(*state["ellipse_cache"] if state["ellipse_cache"] else ([], [])) + ax0.draw_artist(ell_line) + sc1.set_offsets(np.c_[state["xu"], state["yu"]] if state["xu"] else np.empty((0, 2))) + ax1.draw_artist(sc1) + if state["latest_z"]: + zx, zy = state["latest_z"] + radius_line.set_data([0.0, zx], [0.0, zy]) + ang = float(np.arctan2(zy, zx)) + angle_text.set_text( + f"angle: {ang:+.3f} rad ({np.degrees(ang):+.1f}°)\nmove {name}, press 'n' to advance" + ) + else: + radius_line.set_data([], []) + angle_text.set_text("(waiting for fit)") + else: + sc0.set_offsets(np.empty((0, 2))) + ax0.draw_artist(sc0) + ell_line.set_data([], []) + ax0.draw_artist(ell_line) + if state["latest_z"]: + zx, zy = state["latest_z"] + sc1.set_offsets([[zx, zy]]) + radius_line.set_data([0.0, zx], [0.0, zy]) + ang = float(np.arctan2(zy, zx)) + angle_text.set_text( + f"Zero pose for {name}\nangle: {ang:+.3f} rad\nsamples: {len(zero_samples)}/{params.sample_count}\nhold still, press 'n'" + ) + else: + sc1.set_offsets(np.empty((0, 2))) + radius_line.set_data([], []) + angle_text.set_text("(waiting for data)") + ax1.draw_artist(sc1) + + ax1.draw_artist(radius_line) + ax1.draw_artist(angle_text) + fig.canvas.blit(ax0.bbox) + fig.canvas.blit(ax1.bbox) + fig.canvas.flush_events() + last_draw = t + + plt.pause(0.001) + + finally: + plt.close(fig) diff --git a/src/lerobot/teleoperators/unitree_g1/exo_ik.py b/src/lerobot/teleoperators/unitree_g1/exo_ik.py new file mode 100644 index 000000000..92519540f --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_ik.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +IK helper for exoskeleton-to-G1 teleoperation. We map Exoskeleton joint angles to end-effector pose in world frame, +visualizing the result in meshcat after calibration. +""" + +import logging +import os +from dataclasses import dataclass + +import numpy as np + +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex +from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK + +from .exo_calib import JOINTS + +logger = logging.getLogger(__name__) + + +def _frame_id(model, name: str) -> int | None: + try: + fid = model.getFrameId(name) + return fid if 0 <= fid < model.nframes else None + except Exception: + return None + + +@dataclass +class ArmCfg: + side: str # "left" | "right" + urdf: str # exo_left.urdf / exo_right.urdf + root: str # "exo_left" / "exo_right" + g1_ee: str # "l_ee" / "r_ee" + offset: np.ndarray # world offset for viz + target + marker_prefix: str # "left" / "right" + + +class Markers: + """Creates meshcat visualization primitives, showing end-effector frames of exoskeleton and G1""" + + def __init__(self, viewer): + self.v = viewer + + def sphere(self, path: str, r: float, rgba: tuple[float, float, float, float]): + import meshcat.geometry as mg + + c = (int(rgba[0] * 255) << 16) | (int(rgba[1] * 255) << 8) | int(rgba[2] * 255) + self.v[path].set_object( + mg.Sphere(r), + mg.MeshPhongMaterial(color=c, opacity=rgba[3], transparent=rgba[3] < 1.0), + ) + + def axes(self, path: str, axis_len: float = 0.1, axis_w: int = 6): + import meshcat.geometry as mg + + pts = np.array( + [[0, 0, 0], [axis_len, 0, 0], 
[0, 0, 0], [0, axis_len, 0], [0, 0, 0], [0, 0, axis_len]], + dtype=np.float32, + ).T + cols = np.array( + [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]], + dtype=np.float32, + ).T + self.v[path].set_object( + mg.LineSegments( + mg.PointsGeometry(position=pts, color=cols), + mg.LineBasicMaterial(linewidth=axis_w, vertexColors=True), + ) + ) + + def tf(self, path: str, mat: np.ndarray): + self.v[path].set_transform(mat) + + +class ExoskeletonIKHelper: + """ + - Loads G1 robot and exoskeleton URDF models via Pinocchio + - Computes forward kinematics on exoskeleton to get end-effector poses + - Solves inverse kinematics on G1 to match those poses + - Provides meshcat visualization showing both robots and targets + + Args: + frozen_joints: List of G1 joint names to exclude from IK (kept at neutral). + """ + + def __init__(self, frozen_joints: list[str] | None = None): + try: + import pinocchio as pin + except ImportError as e: + raise ImportError("ik mode needs pinocchio: pip install pin") from e + + self.pin = pin + self.frozen_joints = frozen_joints or [] + + self.g1_ik = G1_29_ArmIK() + self.robot_g1 = self.g1_ik.reduced_robot + self.robot_g1.data = self.robot_g1.model.createData() + self.q_g1 = pin.neutral(self.robot_g1.model) + + assets_dir = os.path.join(self.g1_ik.repo_path, "assets") + + self.frozen_idx = self._frozen_joint_indices() + + self.arms = [ + ArmCfg( + side="left", + urdf=os.path.join(assets_dir, "exo_left.urdf"), + root="exo_left", + g1_ee="L_ee", + offset=np.array([0.6, 0.3, 0.0]), + marker_prefix="left", + ), + ArmCfg( + side="right", + urdf=os.path.join(assets_dir, "exo_right.urdf"), + root="exo_right", + g1_ee="R_ee", + offset=np.array([0.6, -0.3, 0.0]), + marker_prefix="right", + ), + ] + + self.exo = {} # side -> pin.RobotWrapper + self.q_exo = {} # side -> q + self.ee_id_exo = {} # side -> frame id + self.qmap = {} # side -> {joint_name: q_idx} + self.ee_id_g1 = {} # side -> frame id + + self._load_exo_models(assets_dir) + 
for a in self.arms: + self.ee_id_g1[a.side] = _frame_id(self.robot_g1.model, a.g1_ee) + + self.viewer = None + self.markers: Markers | None = None + self.viz_g1 = None + self.viz_exo = {} # side -> viz + + def _frozen_joint_indices(self) -> dict[str, int]: + out = {} + m = self.robot_g1.model + for name in self.frozen_joints: + if name in m.names: + jid = m.getJointId(name) + out[name] = m.idx_qs[jid] + logger.info(f"freezing joint: {name} (q_idx={out[name]})") + return out + + def _find_exo_ee(self, model, ee_name: str = "ee") -> int: + ee = _frame_id(model, ee_name) + if ee is not None: + return ee + for fid in reversed(range(model.nframes)): + if model.frames[fid].type == self.pin.FrameType.BODY: + return fid + return 0 + + def _build_joint_map(self, robot) -> dict[str, int]: + m = robot.model + return {n: m.idx_qs[m.getJointId(n)] for n in JOINTS if n in m.names} + + def _load_exo_models(self, assets_dir: str): + pin = self.pin + for a in self.arms: + if not os.path.exists(a.urdf): + logger.warning(f"{a.side} exo urdf not found: {a.urdf}") + continue + r = pin.RobotWrapper.BuildFromURDF(a.urdf, assets_dir) + self.exo[a.side] = r + self.q_exo[a.side] = pin.neutral(r.model) + self.ee_id_exo[a.side] = self._find_exo_ee(r.model) + self.qmap[a.side] = self._build_joint_map(r) + logger.info(f"loaded {a.side} exo urdf: {a.urdf}") + + def init_visualization(self): + """ + Creates a browser-based visualization of exoskeleton and G1 robot, + highlighting end-effector frames and target positions. 
+ """ + try: + from pinocchio.visualize import MeshcatVisualizer + except ImportError as e: + logger.warning(f"meshcat viz unavailable: {e}") + return + + # g1 + self.viz_g1 = MeshcatVisualizer( + self.robot_g1.model, self.robot_g1.collision_model, self.robot_g1.visual_model + ) + self.viz_g1.initViewer(open=True) + self.viz_g1.loadViewerModel("g1") + self.viz_g1.display(self.q_g1) + + self.viewer = self.viz_g1.viewer + self.markers = Markers(self.viewer) + + # exos + for a in self.arms: + if a.side not in self.exo: + continue + r = self.exo[a.side] + v = MeshcatVisualizer(r.model, r.collision_model, r.visual_model) + v.initViewer(open=False) + v.viewer = self.viewer + v.loadViewerModel(a.root) + offset_tf = np.eye(4) + offset_tf[:3, 3] = a.offset + self.viewer[a.root].set_transform(offset_tf) + v.display(self.q_exo[a.side]) + self.viz_exo[a.side] = v + + # markers + for a in self.arms: + p = a.marker_prefix + self.markers.sphere(f"markers/{p}_exo_ee", 0.012, (0.2, 1.0, 0.2, 0.9)) + self.markers.sphere(f"markers/{p}_g1_ee", 0.015, (1.0, 0.2, 0.2, 0.9)) + self.markers.sphere(f"markers/{p}_ik_target", 0.015, (0.1, 0.3, 1.0, 0.9)) + self.markers.axes(f"markers/{p}_exo_axes", 0.06) + self.markers.axes(f"markers/{p}_g1_axes", 0.08) + + logger.info(f"meshcat viz initialized: {self.viewer.url()}") + print(f"\nmeshcat url: {self.viewer.url()}\n") + + def _fk_target_world(self, side: str, angles: dict[str, float]) -> np.ndarray | None: + """returns wrist frame target to be used for G1 IK in 4x4 homogeneous transform. 
Takes offset into account.""" + if side not in self.exo or not angles: + return None + + pin = self.pin + q = self.q_exo[side] + qmap = self.qmap[side] + + for name, ang in angles.items(): + idx = qmap.get(name) + if idx is not None: + q[idx] = float(ang) + + r = self.exo[side] + pin.forwardKinematics(r.model, r.data, q) + pin.updateFramePlacements(r.model, r.data) + + ee = r.data.oMf[self.ee_id_exo[side]] + target = np.eye(4) + target[:3, :3] = ee.rotation + # offset gets applied in world space + cfg = next(a for a in self.arms if a.side == side) + target[:3, 3] = cfg.offset + ee.translation + return target + + def update_visualization(self): + if self.viewer is None or self.markers is None: + return + + pin = self.pin + + # g1 + if self.viz_g1 is not None: + self.viz_g1.display(self.q_g1) + pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1) + pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data) + + for a in self.arms: + fid = self.ee_id_g1.get(a.side) + if fid is None: + continue + ee_tf = self.robot_g1.data.oMf[fid].homogeneous + p = a.marker_prefix + self.markers.tf(f"markers/{p}_g1_ee", ee_tf) + self.markers.tf(f"markers/{p}_g1_axes", ee_tf) + + # exos + for a in self.arms: + side = a.side + v = self.viz_exo.get(side) + if v is None: + continue + + v.display(self.q_exo[side]) + r = self.exo[side] + pin.forwardKinematics(r.model, r.data, self.q_exo[side]) + pin.updateFramePlacements(r.model, r.data) + + ee = r.data.oMf[self.ee_id_exo[side]] + world_tf = (pin.SE3(np.eye(3), a.offset) * ee).homogeneous + p = a.marker_prefix + self.markers.tf(f"markers/{p}_exo_ee", world_tf) + self.markers.tf(f"markers/{p}_exo_axes", world_tf) + + target_tf = np.eye(4) + target_tf[:3, :3] = ee.rotation + target_tf[:3, 3] = a.offset + ee.translation + self.markers.tf(f"markers/{p}_ik_target", target_tf) + + def compute_g1_joints_from_exo( + self, + left_angles: dict[str, float], + right_angles: dict[str, float], + ) -> dict[str, float]: + """ + 
Performs FK on exoskeleton to get end-effector poses in world frame, + after which it solves IK on G1 to return joint angles matching those poses in G1 motor order. + """ + pin = self.pin + + targets = { + "left": self._fk_target_world("left", left_angles), + "right": self._fk_target_world("right", right_angles), + } + + # fallback to current g1 ee pose if missing target + pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1) + pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data) + + for a in self.arms: + if targets[a.side] is not None: + continue + fid = self.ee_id_g1.get(a.side) + if fid is not None: + targets[a.side] = self.robot_g1.data.oMf[fid].homogeneous + + if targets["left"] is None or targets["right"] is None: + logger.warning("missing ik targets, returning current pose") + return {} + + frozen_vals = {n: self.q_g1[i] for n, i in self.frozen_idx.items()} + + self.q_g1, _ = self.g1_ik.solve_ik( + targets["left"], targets["right"], current_lr_arm_motor_q=self.q_g1 + ) + + for n, i in self.frozen_idx.items(): + self.q_g1[i] = frozen_vals[n] + + return { + f"{j.name}.q": float(self.q_g1[i]) + for i, j in enumerate(G1_29_JointArmIndex) + if i < len(self.q_g1) + } diff --git a/src/lerobot/teleoperators/unitree_g1/exo_serial.py b/src/lerobot/teleoperators/unitree_g1/exo_serial.py new file mode 100644 index 000000000..1211c57cc --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_serial.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from dataclasses import dataclass +from pathlib import Path + +import serial + +from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration + +logger = logging.getLogger(__name__) + + +def parse_raw16(line: bytes) -> list[int] | None: + try: + parts = line.decode("utf-8", errors="ignore").split() + if len(parts) < 16: + return None + return [int(x) for x in parts[:16]] + except Exception: + return None + + +def read_raw_from_serial(ser) -> list[int] | None: + """Read latest sample from serial; if buffer is backed up, keep only the newest.""" + last = None + while ser.in_waiting > 0: + b = ser.readline() + if not b: + break + raw16 = parse_raw16(b) + if raw16 is not None: + last = raw16 + if last is None: + b = ser.readline() + if b: + last = parse_raw16(b) + return last + + +@dataclass +class ExoskeletonArm: + port: str + calibration_fpath: Path + side: str + baud_rate: int = 115200 + + _ser: serial.Serial | None = None + calibration: ExoskeletonCalibration | None = None + + def __post_init__(self): + if self.calibration_fpath.is_file(): + self._load_calibration() + + @property + def is_connected(self) -> bool: + return self._ser is not None and getattr(self._ser, "is_open", False) + + @property + def is_calibrated(self) -> bool: + return self.calibration is not None + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + return + try: + self._ser = serial.Serial(self.port, self.baud_rate, timeout=0.02) + self._ser.reset_input_buffer() + logger.info(f"connected: {self.port}") 
+ except serial.SerialException as e: + raise ConnectionError(f"failed to connect to {self.port}: {e}") from e + + if calibrate and not self.is_calibrated: + self.calibrate() + + def disconnect(self) -> None: + if self._ser: + try: + self._ser.close() + finally: + self._ser = None + + def _load_calibration(self) -> None: + try: + data = json.loads(self.calibration_fpath.read_text()) + self.calibration = ExoskeletonCalibration.from_dict(data) + logger.info(f"loaded calibration: {self.calibration_fpath}") + except Exception as e: + logger.warning(f"failed to load calibration: {e}") + + def read_raw(self) -> list[int] | None: + if not self._ser: + return None + return read_raw_from_serial(self._ser) + + def get_angles(self) -> dict[str, float]: + if not self.calibration: + raise RuntimeError("exoskeleton not calibrated") + raw = self.read_raw() + return {} if raw is None else exo_raw_to_angles(raw, self.calibration) + + def calibrate(self) -> None: + ser = self._ser + self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath) diff --git a/src/lerobot/teleoperators/unitree_g1/unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py new file mode 100644 index 000000000..3779d83ec --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import time +from functools import cached_property + +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS + +from ..teleoperator import Teleoperator +from .config_unitree_g1 import UnitreeG1TeleoperatorConfig +from .exo_ik import ExoskeletonIKHelper +from .exo_serial import ExoskeletonArm + +logger = logging.getLogger(__name__) + + +class UnitreeG1Teleoperator(Teleoperator): + """ + Bimanual exoskeleton arms teleoperator for Unitree G1 arms. + + Uses inverse kinematics: exoskeleton FK computes end-effector pose, + G1 IK solves for joint angles. + """ + + config_class = UnitreeG1TeleoperatorConfig + name = "unitree_g1" + + def __init__(self, config: UnitreeG1TeleoperatorConfig): + super().__init__(config) + self.config = config + + # Setup calibration directory + self.calibration_dir = ( + config.calibration_dir + if config.calibration_dir + else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name + ) + self.calibration_dir.mkdir(parents=True, exist_ok=True) + + left_id = f"{config.id}_left" if config.id else "left" + right_id = f"{config.id}_right" if config.id else "right" + + # Create exoskeleton arm instances + self.left_arm = ExoskeletonArm( + port=config.left_arm_config.port, + baud_rate=config.left_arm_config.baud_rate, + calibration_fpath=self.calibration_dir / f"{left_id}.json", + side="left", + ) + self.right_arm = ExoskeletonArm( + port=config.right_arm_config.port, + baud_rate=config.right_arm_config.baud_rate, + calibration_fpath=self.calibration_dir / f"{right_id}.json", + side="right", + ) + + self.ik_helper: ExoskeletonIKHelper | None = None + + @cached_property + def action_features(self) -> dict[str, type]: + return {f"{name}.q": float for name in self._g1_joint_names} + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and 
self.right_arm.is_connected + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()] + self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) + logger.info("IK helper initialized") + + def calibrate(self) -> None: + if not self.left_arm.is_calibrated: + logger.info("Starting calibration for left arm...") + self.left_arm.calibrate() + else: + logger.info("Left arm already calibrated. Skipping.") + + if not self.right_arm.is_calibrated: + logger.info("Starting calibration for right arm...") + self.right_arm.calibrate() + else: + logger.info("Right arm already calibrated. Skipping.") + + logger.info("Starting visualization to verify calibration...") + self.run_visualization_loop() + + def configure(self) -> None: + pass + + def get_action(self) -> dict[str, float]: + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Exoskeleton arms do not support feedback") + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() + + def run_visualization_loop(self): + """Run interactive Meshcat visualization loop to verify tracking.""" + if self.ik_helper is None: + frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()] + self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) + + self.ik_helper.init_visualization() + + print("\n" + "=" * 60) + print("Visualization running! 
Move the exoskeletons to test tracking.") + print("Press Ctrl+C to exit.") + print("=" * 60 + "\n") + + try: + while True: + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + + self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + self.ik_helper.update_visualization() + + time.sleep(0.01) + + except KeyboardInterrupt: + print("\n\nVisualization stopped.") + + @cached_property + def _g1_joint_names(self) -> list[str]: + return [joint.name for joint in G1_29_JointIndex] diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 8f6bbc787..3b42d294e 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -75,6 +75,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .homunculus import HomunculusArm return HomunculusArm(config) + elif config.type == "unitree_g1": + from .unitree_g1 import UnitreeG1Teleoperator + + return UnitreeG1Teleoperator(config) elif config.type == "bi_so_leader": from .bi_so_leader import BiSOLeader From 4483184875e0c283066ba304e2404638f8aa803e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 28 Jan 2026 17:25:57 +0100 Subject: [PATCH 08/43] feat(robots): add bi manual openarm follower and leader (#2835) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * feat(robot): add openarm leader Co-authored-by: Pepijn * feat(robot): add openarm follower Co-authored-by: Pepijn * refactor(robot): remove mechanical compensations and double arm assumption + rename * chore(robots): remove left arm references * refactor(teleop): multiple improvements to leader * refactor(teleop): multiple improvements to 
leader * feat(robots): add open arm to util CLI * chore(robot): add alias openarm * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * fix(robots): open arm mirrored config for joint limits * chore(motors): update position_kd gain values * chore(robots): set to 0 if openarm is calibrated at connect time * chore(robots): remove macos in open arm as can doesn't support it * chore(robots): update for motor_type_str in Motor class * chore(robots): no default value for can port in open arms * feat(robots): add bi manual openarm follower and leader * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests * remove comment * Add openarms docs * format * update purchase link * can to none if nit availabl;e * add canfd option in bus * make handshake logic similar to lerobot-can * type hint * type check * add temp teleop test * remove script * mock class * mock class * ignore linter * pre-commit * Add command for bimanual openarm * fix import * fix import leader * fix import draccus --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/openarm.mdx | 18 ++ .../robots/bi_openarm_follower/__init__.py | 20 ++ 
.../bi_openarm_follower.py | 175 ++++++++++++++++++ .../config_bi_openarm_follower.py | 30 +++ .../robots/openarm_follower/__init__.py | 4 +- .../config_openarm_follower.py | 11 +- src/lerobot/robots/utils.py | 4 + src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../bi_openarm_leader/__init__.py | 20 ++ .../bi_openarm_leader/bi_openarm_leader.py | 131 +++++++++++++ .../config_bi_openarm_leader.py | 30 +++ .../teleoperators/openarm_leader/__init__.py | 4 +- .../openarm_leader/config_openarm_leader.py | 11 +- src/lerobot/teleoperators/utils.py | 4 + 18 files changed, 461 insertions(+), 10 deletions(-) create mode 100644 src/lerobot/robots/bi_openarm_follower/__init__.py create mode 100644 src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py create mode 100644 src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/__init__.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py diff --git a/docs/source/openarm.mdx b/docs/source/openarm.mdx index 661808749..cd4ace912 100644 --- a/docs/source/openarm.mdx +++ b/docs/source/openarm.mdx @@ -174,6 +174,24 @@ lerobot-teleoperate \ --teleop.id=my_leader ``` +### Bimanual Teleoperation + +To teleoperate a bimanual OpenArm setup with two leader and two follower arms: + +```bash +lerobot-teleoperate \ + --robot.type=bi_openarm_follower \ + --robot.left_arm_config.port=can0 \ + --robot.left_arm_config.side=left \ + --robot.right_arm_config.port=can1 \ + --robot.right_arm_config.side=right \ + --robot.id=my_bimanual_follower \ + --teleop.type=bi_openarm_leader \ + --teleop.left_arm_config.port=can2 \ + --teleop.right_arm_config.port=can3 \ + 
--teleop.id=my_bimanual_leader +``` + ### Recording Data To record a dataset during teleoperation: diff --git a/src/lerobot/robots/bi_openarm_follower/__init__.py b/src/lerobot/robots/bi_openarm_follower/__init__.py new file mode 100644 index 000000000..b1dcce431 --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bi_openarm_follower import BiOpenArmFollower +from .config_bi_openarm_follower import BiOpenArmFollowerConfig + +__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"] diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py new file mode 100644 index 000000000..466eb07e5 --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import cached_property + +from lerobot.processor import RobotAction, RobotObservation +from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig + +from ..robot import Robot +from .config_bi_openarm_follower import BiOpenArmFollowerConfig + +logger = logging.getLogger(__name__) + + +class BiOpenArmFollower(Robot): + """ + Bimanual OpenArm Follower Arms + """ + + config_class = BiOpenArmFollowerConfig + name = "bi_openarm_follower" + + def __init__(self, config: BiOpenArmFollowerConfig): + super().__init__(config) + self.config = config + + left_arm_config = OpenArmFollowerConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_config.port, + disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, + max_relative_target=config.left_arm_config.max_relative_target, + cameras=config.left_arm_config.cameras, + side=config.left_arm_config.side, + can_interface=config.left_arm_config.can_interface, + use_can_fd=config.left_arm_config.use_can_fd, + can_bitrate=config.left_arm_config.can_bitrate, + can_data_bitrate=config.left_arm_config.can_data_bitrate, + motor_config=config.left_arm_config.motor_config, + position_kd=config.left_arm_config.position_kd, + position_kp=config.left_arm_config.position_kp, + joint_limits=config.left_arm_config.joint_limits, + ) + + right_arm_config = OpenArmFollowerConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + 
port=config.right_arm_config.port, + disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, + max_relative_target=config.right_arm_config.max_relative_target, + cameras=config.right_arm_config.cameras, + side=config.right_arm_config.side, + can_interface=config.right_arm_config.can_interface, + use_can_fd=config.right_arm_config.use_can_fd, + can_bitrate=config.right_arm_config.can_bitrate, + can_data_bitrate=config.right_arm_config.can_data_bitrate, + motor_config=config.right_arm_config.motor_config, + position_kd=config.right_arm_config.position_kd, + position_kp=config.right_arm_config.position_kp, + joint_limits=config.right_arm_config.joint_limits, + ) + + self.left_arm = OpenArmFollower(left_arm_config) + self.right_arm = OpenArmFollower(right_arm_config) + + # Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute + self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras} + + @property + def _motors_ft(self) -> dict[str, type]: + left_arm_motors_ft = self.left_arm._motors_ft + right_arm_motors_ft = self.right_arm._motors_ft + + return { + **{f"left_{k}": v for k, v in left_arm_motors_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_motors_ft.items()}, + } + + @property + def _cameras_ft(self) -> dict[str, tuple]: + left_arm_cameras_ft = self.left_arm._cameras_ft + right_arm_cameras_ft = self.right_arm._cameras_ft + + return { + **{f"left_{k}": v for k, v in left_arm_cameras_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_cameras_ft.items()}, + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + 
self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_observation(self) -> RobotObservation: + obs_dict = {} + + # Add "left_" prefix + left_obs = self.left_arm.get_observation() + obs_dict.update({f"left_{key}": value for key, value in left_obs.items()}) + + # Add "right_" prefix + right_obs = self.right_arm.get_observation() + obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) + + return obs_dict + + def send_action( + self, + action: RobotAction, + custom_kp: dict[str, float] | None = None, + custom_kd: dict[str, float] | None = None, + ) -> RobotAction: + # Remove "left_" prefix + left_action = { + key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_") + } + # Remove "right_" prefix + right_action = { + key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_") + } + + sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd) + sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd) + + # Add prefixes back + prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} + prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} + + return {**prefixed_sent_action_left, **prefixed_sent_action_right} + + def disconnect(self): + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py 
b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py new file mode 100644 index 000000000..9d11f7b4e --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("bi_openarm_follower") +@dataclass +class BiOpenArmFollowerConfig(RobotConfig): + """Configuration class for Bi OpenArm Follower robots.""" + + left_arm_config: OpenArmFollowerConfigBase + right_arm_config: OpenArmFollowerConfigBase diff --git a/src/lerobot/robots/openarm_follower/__init__.py b/src/lerobot/robots/openarm_follower/__init__.py index 1eb0d9fc7..217432fd5 100644 --- a/src/lerobot/robots/openarm_follower/__init__.py +++ b/src/lerobot/robots/openarm_follower/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .config_openarm_follower import OpenArmFollowerConfig +from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase from .openarm_follower import OpenArmFollower -__all__ = ["OpenArmFollower", "OpenArmFollowerConfig"] +__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"] diff --git a/src/lerobot/robots/openarm_follower/config_openarm_follower.py b/src/lerobot/robots/openarm_follower/config_openarm_follower.py index af95b6395..88d81fd50 100644 --- a/src/lerobot/robots/openarm_follower/config_openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/config_openarm_follower.py @@ -43,10 +43,9 @@ RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { } -@RobotConfig.register_subclass("openarm_follower") @dataclass -class OpenArmFollowerConfig(RobotConfig): - """Configuration for the OpenArms follower robot with Damiao motors.""" +class OpenArmFollowerConfigBase: + """Base configuration for the OpenArms follower robot with Damiao motors.""" # CAN interfaces - one per arm # arm CAN interface (e.g., "can1") @@ -115,3 +114,9 @@ class OpenArmFollowerConfig(RobotConfig): "gripper": (-5.0, 0.0), } ) + + +@RobotConfig.register_subclass("openarm_follower") +@dataclass +class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase): + pass diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index e0c76cab3..92da597f1 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -64,6 +64,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .openarm_follower import OpenArmFollower return OpenArmFollower(config) + elif config.type == "bi_openarm_follower": + from .bi_openarm_follower import BiOpenArmFollower + + return BiOpenArmFollower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 2fa1b2a03..eb3df6872 100644 
--- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, hope_jr, koch_follower, @@ -48,6 +49,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, homunculus, koch_leader, diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index d928dc5cd..082d11803 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -44,6 +44,7 @@ import numpy as np from lerobot.model.kinematics import RobotKinematics from lerobot.robots import ( # noqa: F401 RobotConfig, + bi_openarm_follower, bi_so_follower, koch_follower, make_robot_from_config, @@ -53,6 +54,7 @@ from lerobot.robots import ( # noqa: F401 ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, gamepad, koch_leader, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index d621189e8..0b39e6fff 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -98,6 +98,7 @@ from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, @@ -112,6 +113,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, homunculus, koch_leader, diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c3bc3d766..5717dffb6 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ 
b/src/lerobot/scripts/lerobot_replay.py @@ -53,6 +53,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 958bd00ef..b6aa4a750 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -70,6 +70,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, @@ -84,6 +85,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, gamepad, homunculus, diff --git a/src/lerobot/teleoperators/bi_openarm_leader/__init__.py b/src/lerobot/teleoperators/bi_openarm_leader/__init__.py new file mode 100644 index 000000000..fe728b826 --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .bi_openarm_leader import BiOpenArmLeader +from .config_bi_openarm_leader import BiOpenArmLeaderConfig + +__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"] diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py new file mode 100644 index 000000000..c4383293f --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from functools import cached_property + +from lerobot.processor import RobotAction +from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig + +from ..openarm_leader import OpenArmLeader +from ..teleoperator import Teleoperator +from .config_bi_openarm_leader import BiOpenArmLeaderConfig + +logger = logging.getLogger(__name__) + + +class BiOpenArmLeader(Teleoperator): + """ + Bimanual OpenArm Leader Arms + """ + + config_class = BiOpenArmLeaderConfig + name = "bi_openarm_leader" + + def __init__(self, config: BiOpenArmLeaderConfig): + super().__init__(config) + self.config = config + + left_arm_config = OpenArmLeaderConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_config.port, + can_interface=config.left_arm_config.can_interface, + use_can_fd=config.left_arm_config.use_can_fd, + can_bitrate=config.left_arm_config.can_bitrate, + can_data_bitrate=config.left_arm_config.can_data_bitrate, + motor_config=config.left_arm_config.motor_config, + manual_control=config.left_arm_config.manual_control, + position_kd=config.left_arm_config.position_kd, + position_kp=config.left_arm_config.position_kp, + ) + + right_arm_config = OpenArmLeaderConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.right_arm_config.port, + can_interface=config.right_arm_config.can_interface, + use_can_fd=config.right_arm_config.use_can_fd, + can_bitrate=config.right_arm_config.can_bitrate, + can_data_bitrate=config.right_arm_config.can_data_bitrate, + motor_config=config.right_arm_config.motor_config, + manual_control=config.right_arm_config.manual_control, + position_kd=config.right_arm_config.position_kd, + position_kp=config.right_arm_config.position_kp, + ) + + self.left_arm = OpenArmLeader(left_arm_config) + self.right_arm = OpenArmLeader(right_arm_config) + + @cached_property + def action_features(self) -> dict[str, type]: + 
left_arm_features = self.left_arm.action_features + right_arm_features = self.right_arm.action_features + + return { + **{f"left_{k}": v for k, v in left_arm_features.items()}, + **{f"right_{k}": v for k, v in right_arm_features.items()}, + } + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." 
+ ) + + def get_action(self) -> RobotAction: + action_dict = {} + + # Add "left_" prefix + left_action = self.left_arm.get_action() + action_dict.update({f"left_{key}": value for key, value in left_action.items()}) + + # Add "right_" prefix + right_action = self.right_arm.get_action() + action_dict.update({f"right_{key}": value for key, value in right_action.items()}) + + return action_dict + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO: Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py new file mode 100644 index 000000000..39fc90add --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("bi_openarm_leader") +@dataclass +class BiOpenArmLeaderConfig(TeleoperatorConfig): + """Configuration class for Bi OpenArm Leader robots.""" + + left_arm_config: OpenArmLeaderConfigBase + right_arm_config: OpenArmLeaderConfigBase diff --git a/src/lerobot/teleoperators/openarm_leader/__init__.py b/src/lerobot/teleoperators/openarm_leader/__init__.py index 1493317fe..172cf8228 100644 --- a/src/lerobot/teleoperators/openarm_leader/__init__.py +++ b/src/lerobot/teleoperators/openarm_leader/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_openarm_leader import OpenArmLeaderConfig +from .config_openarm_leader import OpenArmLeaderConfig, OpenArmLeaderConfigBase from .openarm_leader import OpenArmLeader -__all__ = ["OpenArmLeader", "OpenArmLeaderConfig"] +__all__ = ["OpenArmLeader", "OpenArmLeaderConfig", "OpenArmLeaderConfigBase"] diff --git a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py index c53169b0a..4b12fe730 100644 --- a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py @@ -19,10 +19,9 @@ from dataclasses import dataclass, field from ..config import TeleoperatorConfig -@TeleoperatorConfig.register_subclass("openarm_leader") @dataclass -class OpenArmLeaderConfig(TeleoperatorConfig): - """Configuration for the OpenArms leader/teleoperator with Damiao motors.""" +class OpenArmLeaderConfigBase: + """Base configuration for the OpenArms leader/teleoperator with Damiao motors.""" # CAN interfaces - one per arm # Arm CAN interface (e.g., "can3") @@ -68,3 +67,9 @@ class
OpenArmLeaderConfig(TeleoperatorConfig): default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0] ) position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2]) + + +@TeleoperatorConfig.register_subclass("openarm_leader") +@dataclass +class OpenArmLeaderConfig(TeleoperatorConfig, OpenArmLeaderConfigBase): + pass diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 3b42d294e..16454d5ad 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -91,6 +91,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .openarm_leader import OpenArmLeader return OpenArmLeader(config) + elif config.type == "bi_openarm_leader": + from .bi_openarm_leader import BiOpenArmLeader + + return BiOpenArmLeader(config) else: try: return cast("Teleoperator", make_device_from_device_class(config)) From 3409ef0dc2fc948468ec573a1e47f8ab7747cee2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 29 Jan 2026 04:07:47 -0600 Subject: [PATCH 09/43] refactor(cameras): cameras API extension (#2808) * feat(cameras): add new read_latest() method * fix(cameras): fix threading bug + clear state * refactor(cameras): multiple improvements * feat(camera): add context manager to camera base class * chore(camera): slight modifications to opencv * test(cameras): update opencv tests according to the changes * refactor(cameras): reflect desing changes to realsense + deal with depth * test(cameras): fix realsense tests accordingly to new changes * refactor(cameras): update reachymini and zmq accordingly * chore: wrap resource sensitive examples into a try/finally * test(cameras): add test for new read_latest * test(cameras): fix problem with image artifact in opencv tests * test(cameras): fix test_read_latest_high_frequency expectations * Apply suggestions from code review 1 Co-authored-by: Caroline Pascal Signed-off-by: Steven Palma * 
chore(cameras): address feedback * feat(cameras): add max_age_ms check in read_latest * test(cameras): fix read_latest tests * chore(redundancies): removing redundancies in Reachy 2 camera class * fix(warmup): replacing the arbitrary time.sleep in by an actual warmup in the RealSense camera class * chore(format): formatting latest changes * chore(warning): adding a "to be implemented" warning for read_latest() in Camera base class * chore(warning): making read_latest() warning message shorter and clearer --------- Signed-off-by: Steven Palma Co-authored-by: Caroline Pascal --- examples/backward_compatibility/replay.py | 31 +-- examples/lekiwi/evaluate.py | 88 +++---- examples/lekiwi/record.py | 89 +++---- examples/lekiwi/replay.py | 32 +-- examples/phone_to_so100/evaluate.py | 85 ++++--- examples/phone_to_so100/record.py | 85 ++++--- examples/phone_to_so100/replay.py | 42 ++-- examples/so100_to_so100_EE/evaluate.py | 85 ++++--- examples/so100_to_so100_EE/record.py | 86 ++++--- examples/so100_to_so100_EE/replay.py | 41 +-- src/lerobot/cameras/camera.py | 100 ++++++-- src/lerobot/cameras/opencv/camera_opencv.py | 166 ++++++++---- .../cameras/reachy2_camera/reachy2_camera.py | 67 +++-- .../cameras/realsense/camera_realsense.py | 212 +++++++++++----- src/lerobot/cameras/zmq/camera_zmq.py | 238 ++++++++++++++---- src/lerobot/cameras/zmq/configuration_zmq.py | 1 + src/lerobot/scripts/lerobot_calibrate.py | 7 +- src/lerobot/scripts/lerobot_replay.py | 29 +-- tests/cameras/test_opencv.py | 184 +++++++++----- tests/cameras/test_reachy2_camera.py | 38 +++ tests/cameras/test_realsense.py | 114 +++++---- 21 files changed, 1179 insertions(+), 641 deletions(-) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index ed78d016f..8de5ba197 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig): actions = 
dataset.hf_dataset.select_columns(ACTION) robot.connect() - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() + try: + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() - action_array = actions[idx][ACTION] - action = {} - for i, name in enumerate(dataset.features[ACTION]["names"]): - key = f"{name.removeprefix('main_')}.pos" - action[key] = action_array[i].item() + action_array = actions[idx][ACTION] + action = {} + for i, name in enumerate(dataset.features[ACTION]["names"]): + key = f"{name.removeprefix('main_')}.pos" + action[key] = action_array[i].item() - action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) - action["elbow_flex.pos"] -= 90 - robot.send_action(action) + action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) + action["elbow_flex.pos"] -= 90 + robot.send_action(action) - dt_s = time.perf_counter() - start_episode_t - precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) - - robot.disconnect() + dt_s = time.perf_counter() - start_episode_t + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) + finally: + robot.disconnect() if __name__ == "__main__": diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 2f7f9f95f..a3144a442 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -78,40 +78,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="lekiwi_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - recorded_episodes = 0 - while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") + print("Starting evaluate loop...") + 
recorded_episodes = 0 + while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ( - (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -120,24 +104,42 @@ def main(): robot_observation_processor=robot_observation_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) - # Save episode - dataset.save_episode() - 
recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + recorded_episodes += 1 - dataset.finalize() - dataset.push_to_hub() + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() + + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 18b9f857e..9292157f7 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -74,40 +74,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="lekiwi_record") - if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop...") - recorded_episodes = 0 - while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {recorded_episodes}") + print("Starting record loop...") + recorded_episodes = 0 + while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {recorded_episodes}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - dataset=dataset, - teleop=[leader_arm, keyboard], - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ( - 
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + dataset=dataset, teleop=[leader_arm, keyboard], - control_time_s=RESET_TIME_SEC, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=teleop_action_processor, @@ -115,26 +98,44 @@ def main(): robot_observation_processor=robot_observation_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=[leader_arm, keyboard], + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) - # Save episode - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - leader_arm.disconnect() - keyboard.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + recorded_episodes += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + leader_arm.disconnect() + keyboard.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 
872dacf27..cf89aea16 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -42,25 +42,27 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Send action to robot - _ = robot.send_action(action) + # Send action to robot + _ = robot.send_action(action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - - robot.disconnect() + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + finally: + robot.disconnect() if __name__ == "__main__": diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 246c923aa..837217eda 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -142,38 +142,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="phone_so100_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - episode_idx = 0 - for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting evaluate loop...") + episode_idx = 
0 + for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -182,24 +168,41 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + 
if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 7b5b704e2..1f5005db9 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -149,38 +149,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="phone_so100_record") - if not robot.is_connected or not phone.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not robot.is_connected or not phone.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop. Move your phone to teleoperate the robot...") - episode_idx = 0 - while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting record loop. 
Move your phone to teleoperate the robot...") + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - teleop=phone, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=phone_to_robot_ee_pose_processor, - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, teleop=phone, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=phone_to_robot_ee_pose_processor, @@ -188,25 +173,43 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + 
events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - phone.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + phone.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 875025dfc..9d7806cf4 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -73,32 +73,34 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - ee_action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + ee_action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Get robot observation - robot_obs = robot.get_observation() + # Get robot observation + robot_obs = robot.get_observation() - # Dataset EE -> robot joints - joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) - # Send action to robot - _ = robot.send_action(joint_action) + # Send action to robot + _ = robot.send_action(joint_action) - 
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - - # Clean up - robot.disconnect() + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + finally: + # Clean up + robot.disconnect() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 87d188f99..b614b89f2 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -142,38 +142,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="so100_so100_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - episode_idx = 0 - for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting evaluate loop...") + episode_idx = 0 + for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + 
postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -182,24 +168,41 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index eead7a9a8..d85a1c5cc 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -146,38 +146,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="recording_phone") - if not leader.is_connected or not follower.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not 
leader.is_connected or not follower.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop...") - episode_idx = 0 - while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting record loop...") + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=follower, - events=events, - fps=FPS, - teleop=leader, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=leader_joints_to_ee, - robot_action_processor=ee_to_follower_joints, - robot_observation_processor=follower_joints_to_ee, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=follower, events=events, fps=FPS, teleop=leader, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=leader_joints_to_ee, @@ -185,25 +170,44 @@ def main(): robot_observation_processor=follower_joints_to_ee, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + 
robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - leader.disconnect() - follower.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 - dataset.finalize() - dataset.push_to_hub() + finally: + # Clean up + log_say("Stop recording") + leader.disconnect() + follower.disconnect() + listener.stop() + + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 7d35a7b44..47a2f6635 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -74,32 +74,35 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - ee_action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + ee_action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Get robot observation - robot_obs = robot.get_observation() + # Get robot observation + robot_obs = robot.get_observation() - # Dataset EE -> robot joints - joint_action = 
robot_ee_to_joints_processor((ee_action, robot_obs)) + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) - # Send action to robot - _ = robot.send_action(joint_action) + # Send action to robot + _ = robot.send_action(joint_action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - # Clean up - robot.disconnect() + finally: + # Clean up + robot.disconnect() if __name__ == "__main__": diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index bfdb571a7..2894e0215 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -15,11 +15,12 @@ # limitations under the License. import abc +import warnings from typing import Any from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing -from .configs import CameraConfig, ColorMode +from .configs import CameraConfig class Camera(abc.ABC): @@ -30,20 +31,12 @@ class Camera(abc.ABC): Manages basic camera properties (FPS, resolution) and core operations: - Connection/disconnection - - Frame capture (sync/async) + - Frame capture (sync/async/latest) Attributes: fps (int | None): Configured frames per second width (int | None): Frame width in pixels height (int | None): Frame height in pixels - - Example: - class MyCamera(Camera): - def __init__(self, config): ... - @property - def is_connected(self) -> bool: ... - def connect(self, warmup=True): ... - # Plus other required methods """ def __init__(self, config: CameraConfig): @@ -56,6 +49,32 @@ class Camera(abc.ABC): self.width: int | None = config.width self.height: int | None = config.height + def __enter__(self): + """ + Context manager entry. + Automatically connects to the camera. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. 
+ Automatically disconnects, ensuring resources are released even on error. + """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + @property @abc.abstractmethod def is_connected(self) -> bool: @@ -89,12 +108,10 @@ class Camera(abc.ABC): pass @abc.abstractmethod - def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: - """Capture and return a single frame from the camera. + def read(self) -> NDArray[Any]: + """Capture and return a single frame from the camera synchronously. - Args: - color_mode: Desired color mode for the output frame. If None, - uses the camera's default color mode. + This is a blocking call that will wait for the hardware and its SDK. Returns: np.ndarray: Captured frame as a numpy array. @@ -103,17 +120,64 @@ class Camera(abc.ABC): @abc.abstractmethod def async_read(self, timeout_ms: float = ...) -> NDArray[Any]: - """Asynchronously capture and return a single frame from the camera. + """Return the most recent new frame. + + This method retrieves the latest frame captured by the background thread. + If a new frame is already available in the buffer (captured since the last call), + it returns it immediately. + + It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame + was already consumed by a previous `async_read` call. + + Essentially, this method return the latest unconsumed frame, waiting if necessary + for a new one to arrive within the specified timeout. + + Usage: + - Ideal for control loops where you want to ensure every processed frame + is fresh, effectively synchronizing your loop to the camera's FPS. + - Causes of a timeout usually include: very low camera FPS, heavy processing load, + or if the camera is disconnected. Args: - timeout_ms: Maximum time to wait for a frame in milliseconds. 
- Defaults to implementation-specific timeout. + timeout_ms: Maximum time to wait for a new frame in milliseconds. + Defaults to 200ms (0.2s). Returns: np.ndarray: Captured frame as a numpy array. + + Raises: + TimeoutError: If no new frame arrives within `timeout_ms`. """ pass + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Usage: + Ideal for scenarios requiring zero latency or decoupled frequencies & when + we want a guaranteed frame, such as UI visualization, logging, or + non-critical monitoring. + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + NotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + warnings.warn( + f"{self.__class__.__name__}.read_latest() is not implemented. 
" + "Please override read_latest(); it will be required in future releases.", + FutureWarning, + stacklevel=2, + ) + return self.async_read() + @abc.abstractmethod def disconnect(self) -> None: """Disconnect from the camera and release resources.""" diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index b1043ba64..d581e1425 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -70,34 +70,24 @@ class OpenCVCamera(Camera): Example: ```python from lerobot.cameras.opencv import OpenCVCamera - from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation + from lerobot.cameras.configuration_opencv import OpenCVCameraConfig # Basic usage with camera index 0 config = OpenCVCameraConfig(index_or_path=0) camera = OpenCVCamera(config) camera.connect() - # Read 1 frame synchronously + # Read 1 frame synchronously (blocking) color_image = camera.read() - print(color_image.shape) - # Read 1 frame asynchronously + # Read 1 frame asynchronously (waits for new frame with a timeout) async_image = camera.async_read() + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() + # When done, properly disconnect the camera using camera.disconnect() - - # Example with custom settings - custom_config = OpenCVCameraConfig( - index_or_path='/dev/video0', # Or use an index - fps=30, - width=1280, - height=720, - color_mode=ColorMode.RGB, - rotation=Cv2Rotation.ROTATE_90 - ) - custom_camera = OpenCVCamera(custom_config) - # ... connect, read, disconnect ... 
``` """ @@ -123,6 +113,7 @@ class OpenCVCamera(Camera): self.stop_event: Event | None = None self.frame_lock: Lock = Lock() self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) @@ -146,12 +137,16 @@ class OpenCVCamera(Camera): Connects to the OpenCV camera specified in the configuration. Initializes the OpenCV VideoCapture object, sets desired camera properties - (FPS, width, height), and performs initial checks. + (FPS, width, height), starts the background reading thread and performs initial checks. + + Args: + warmup (bool): If True, waits at connect() time until at least one valid frame + has been captured by the background thread. Defaults to True. Raises: DeviceAlreadyConnectedError: If the camera is already connected. - ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open. - RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings. + ConnectionError: If the specified camera index/path is not found or fails to open. + RuntimeError: If the camera opens but fails to apply requested settings. """ if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} is already connected.") @@ -170,12 +165,16 @@ class OpenCVCamera(Camera): ) self._configure_capture_settings() + self._start_read_thread() - if warmup: + if warmup and self.warmup_s > 0: start_time = time.time() while time.time() - start_time < self.warmup_s: - self.read() + self.async_read(timeout_ms=self.warmup_s * 1000) time.sleep(0.1) + with self.frame_lock: + if self.latest_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") logger.info(f"{self} connected.") @@ -196,8 +195,7 @@ class OpenCVCamera(Camera): Raises: RuntimeError: If the camera fails to set any of the specified properties to the requested value. 
- DeviceNotConnectedError: If the camera is not connected when attempting - to configure settings. + DeviceNotConnectedError: If the camera is not connected. """ if not self.is_connected: raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") @@ -339,6 +337,17 @@ class OpenCVCamera(Camera): return found_cameras_info + def _read_from_hardware(self) -> NDArray[Any]: + if self.videocapture is None: + raise DeviceNotConnectedError(f"{self} videocapture is not initialized") + + ret, frame = self.videocapture.read() + + if not ret: + raise RuntimeError(f"{self} read failed (status={ret}).") + + return frame + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -346,11 +355,6 @@ class OpenCVCamera(Camera): This is a blocking call. It waits for the next available frame from the camera hardware via OpenCV. - Args: - color_mode (Optional[ColorMode]): If specified, overrides the default - color mode (`self.color_mode`) for this read operation (e.g., - request RGB even if default is BGR). - Returns: np.ndarray: The captured frame as a NumPy array in the format (height, width, channels), using the specified or default @@ -362,34 +366,34 @@ class OpenCVCamera(Camera): received frame dimensions don't match expectations before rotation. ValueError: If an invalid `color_mode` is requested. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") start_time = time.perf_counter() - if self.videocapture is None: - raise DeviceNotConnectedError(f"{self} videocapture is not initialized") + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
+ ) - ret, frame = self.videocapture.read() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") - if not ret or frame is None: - raise RuntimeError(f"{self} read failed (status={ret}).") + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") - processed_frame = self._postprocess_image(frame, color_mode) + self.new_frame_event.clear() + frame = self.async_read(timeout_ms=10000) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return processed_frame + return frame - def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]: + def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]: """ Applies color conversion, dimension validation, and rotation to a raw frame. Args: image (np.ndarray): The raw image frame (expected BGR format from OpenCV). - color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, - uses the instance's default `self.color_mode`. Returns: np.ndarray: The processed image frame. @@ -399,11 +403,10 @@ class OpenCVCamera(Camera): RuntimeError: If the raw frame dimensions do not match the configured `width` and `height`. """ - requested_color_mode = self.color_mode if color_mode is None else color_mode - if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." 
) h, w, c = image.shape @@ -417,7 +420,7 @@ class OpenCVCamera(Camera): raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).") processed_image = image - if requested_color_mode == ColorMode.RGB: + if self.color_mode == ColorMode.RGB: processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: @@ -431,7 +434,7 @@ class OpenCVCamera(Camera): On each iteration: 1. Reads a color frame - 2. Stores result in latest_frame (thread-safe) + 2. Stores result in latest_frame and updates timestamp (thread-safe) 3. Sets new_frame_event to notify listeners Stops on DeviceNotConnectedError, logs other errors and continues. @@ -439,30 +442,37 @@ class OpenCVCamera(Camera): if self.stop_event is None: raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") + failure_count = 0 while not self.stop_event.is_set(): try: - color_image = self.read() + raw_frame = self._read_from_hardware() + processed_frame = self._postprocess_image(raw_frame) + capture_time = time.perf_counter() with self.frame_lock: - self.latest_frame = color_image + self.latest_frame = processed_frame + self.latest_timestamp = capture_time self.new_frame_event.set() + failure_count = 0 except DeviceNotConnectedError: break except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Error reading frame in background thread for {self}: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e def _start_read_thread(self) -> None: """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() + self._stop_read_thread() self.stop_event = Event() self.thread = 
Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") self.thread.daemon = True self.thread.start() + time.sleep(0.1) def _stop_read_thread(self) -> None: """Signals the background read thread to stop and waits for it to join.""" @@ -475,6 +485,11 @@ class OpenCVCamera(Camera): self.thread = None self.stop_event = None + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -482,6 +497,7 @@ class OpenCVCamera(Camera): This method retrieves the most recent frame captured by the background read thread. It does not block waiting for the camera hardware directly, but may wait up to timeout_ms for the background thread to provide a frame. + It is “best effort” under high FPS. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -500,13 +516,12 @@ class OpenCVCamera(Camera): raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() raise TimeoutError( f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." + f"Read thread alive: {self.thread.is_alive()}." ) with self.frame_lock: @@ -518,6 +533,42 @@ class OpenCVCamera(Camera): return frame + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). 
+ + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """ Disconnects from the camera and cleans up resources. @@ -538,4 +589,9 @@ class OpenCVCamera(Camera): self.videocapture.release() self.videocapture = None + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index c8916c5ee..5cede466d 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -80,6 +80,8 @@ class Reachy2Camera(Camera): self.config = config self.color_mode = config.color_mode + self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.cam_manager: CameraManager | None = None @@ -125,12 +127,7 @@ class Reachy2Camera(Camera): """ Reads a single frame synchronously from the camera. - This is a blocking call. 
- - Args: - color_mode (Optional[ColorMode]): If specified, overrides the default - color mode (`self.color_mode`) for this read operation (e.g., - request RGB even if default is BGR). + This method retrieves the most recent frame available in Reachy 2's low-level software. Returns: np.ndarray: The captured frame as a NumPy array in the format @@ -145,6 +142,11 @@ class Reachy2Camera(Camera): if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) + frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8) if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): @@ -165,11 +167,18 @@ class Reachy2Camera(Camera): raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.") if frame is None: - return np.empty((0, 0, 3), dtype=np.uint8) + raise RuntimeError(f"Internal error: No frame available for {self}.") - if self.config.color_mode == "rgb": + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + ) + if self.color_mode == ColorMode.RGB: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + self.latest_frame = frame + self.latest_timestamp = time.perf_counter() + read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") @@ -177,13 +186,7 @@ class Reachy2Camera(Camera): def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ - Reads the latest available frame. - - This method retrieves the most recent frame available in Reachy 2's low-level software. - - Args: - timeout_ms (float): Maximum time in milliseconds to wait for a frame - to become available. Defaults to 200ms (0.2 seconds). 
+ Same as read() Returns: np.ndarray: The latest captured frame as a NumPy array in the format @@ -197,12 +200,38 @@ class Reachy2Camera(Camera): if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - frame = self.read() + return self.read() - if frame is None: - raise RuntimeError(f"Internal error: No frame available for {self}.") + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). - return frame + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + tuple[NDArray, float]: + - The frame image (numpy array). + - The timestamp (time.perf_counter) when this frame was captured. + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.latest_frame is None or self.latest_timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." 
+ ) + + return self.latest_frame def disconnect(self) -> None: """ diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index f2906fdd8..e47f25381 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -72,15 +72,14 @@ class RealSenseCamera(Camera): camera = RealSenseCamera(config) camera.connect() - # Read 1 frame synchronously + # Read 1 frame synchronously (blocking) color_image = camera.read() - print(color_image.shape) - # Read 1 frame asynchronously + # Read 1 frame asynchronously (waits for new frame with a timeout) async_image = camera.async_read() - # When done, properly disconnect the camera using - camera.disconnect() + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() # Example with depth capture and custom settings custom_config = RealSenseCameraConfig( @@ -133,7 +132,9 @@ class RealSenseCamera(Camera): self.thread: Thread | None = None self.stop_event: Event | None = None self.frame_lock: Lock = Lock() - self.latest_frame: NDArray[Any] | None = None + self.latest_color_frame: NDArray[Any] | None = None + self.latest_depth_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) @@ -158,6 +159,10 @@ class RealSenseCamera(Camera): Initializes the RealSense pipeline, configures the required streams (color and optionally depth), starts the pipeline, and validates the actual stream settings. + Args: + warmup (bool): If True, waits at connect() time until at least one valid frame + has been captured by the background thread. Defaults to True. + Raises: DeviceAlreadyConnectedError: If the camera is already connected. ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). 
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera): ) from e self._configure_capture_settings() + self._start_read_thread() - if warmup: - time.sleep( - 1 - ) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise. - start_time = time.time() - while time.time() - start_time < self.warmup_s: - self.read() - time.sleep(0.1) + # NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise. + self.warmup_s = max(self.warmup_s, 1) + + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.async_read(timeout_ms=self.warmup_s * 1000) + time.sleep(0.1) + with self.frame_lock: + if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") logger.info(f"{self} connected.") @@ -319,9 +327,6 @@ class RealSenseCamera(Camera): This is a blocking call. It waits for a coherent set of frames (depth) from the camera hardware via the RealSense pipeline. - Args: - timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. - Returns: np.ndarray: The depth map as a NumPy array (height, width) of type `np.uint16` (raw depth values in millimeters) and rotation. @@ -330,44 +335,52 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If reading frames from the pipeline fails or frames are invalid. """ + if timeout_ms: + logger.warning( + f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." + ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if not self.use_depth: raise RuntimeError( f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." 
) - start_time = time.perf_counter() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + self.new_frame_event.clear() + + _ = self.async_read(timeout_ms=10000) + + with self.frame_lock: + depth_map = self.latest_depth_frame + + if depth_map is None: + raise RuntimeError("No depth frame available. Ensure camera is streaming.") + + return depth_map + + def _read_from_hardware(self): if self.rs_pipeline is None: raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.") - ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000) if not ret or frame is None: - raise RuntimeError(f"{self} read_depth failed (status={ret}).") + raise RuntimeError(f"{self} read failed (status={ret}).") - depth_frame = frame.get_depth_frame() - depth_map = np.asanyarray(depth_frame.get_data()) + return frame - depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) - - read_duration_ms = (time.perf_counter() - start_time) * 1e3 - logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - - return depth_map_processed - - def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]: + def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. This is a blocking call. It waits for a coherent set of frames (color) from the camera hardware via the RealSense pipeline. - Args: - timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. - Returns: np.ndarray: The captured color frame as a NumPy array (height, width, channels), processed according to `color_mode` and rotation. 
@@ -378,39 +391,39 @@ class RealSenseCamera(Camera): ValueError: If an invalid `color_mode` is requested. """ + start_time = time.perf_counter() + + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) + + if timeout_ms: + logger.warning( + f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." + ) + if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - start_time = time.perf_counter() + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") - if self.rs_pipeline is None: - raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.") + self.new_frame_event.clear() - ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) - - if not ret or frame is None: - raise RuntimeError(f"{self} read failed (status={ret}).") - - color_frame = frame.get_color_frame() - color_image_raw = np.asanyarray(color_frame.get_data()) - - color_image_processed = self._postprocess_image(color_image_raw, color_mode) + frame = self.async_read(timeout_ms=10000) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return color_image_processed + return frame - def _postprocess_image( - self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False - ) -> NDArray[Any]: + def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]: """ Applies color conversion, dimension validation, and rotation to a raw color frame. Args: image (np.ndarray): The raw image frame (expected RGB format from RealSense). - color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, - uses the instance's default `self.color_mode`. Returns: np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`. 
@@ -421,9 +434,9 @@ class RealSenseCamera(Camera): `width` and `height`. """ - if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): + if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." ) if depth_frame: @@ -454,7 +467,7 @@ class RealSenseCamera(Camera): On each iteration: 1. Reads a color frame with 500ms timeout - 2. Stores result in latest_frame (thread-safe) + 2. Stores result in latest_frame and updates timestamp (thread-safe) 3. Sets new_frame_event to notify listeners Stops on DeviceNotConnectedError, logs other errors and continues. @@ -462,25 +475,41 @@ class RealSenseCamera(Camera): if self.stop_event is None: raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") + failure_count = 0 while not self.stop_event.is_set(): try: - color_image = self.read(timeout_ms=500) + frame = self._read_from_hardware() + color_frame_raw = frame.get_color_frame() + color_frame = np.asanyarray(color_frame_raw.get_data()) + processed_color_frame = self._postprocess_image(color_frame) + + if self.use_depth: + depth_frame_raw = frame.get_depth_frame() + depth_frame = np.asanyarray(depth_frame_raw.get_data()) + processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True) + + capture_time = time.perf_counter() with self.frame_lock: - self.latest_frame = color_image + self.latest_color_frame = processed_color_frame + if self.use_depth: + self.latest_depth_frame = processed_depth_frame + self.latest_timestamp = capture_time self.new_frame_event.set() + failure_count = 0 except DeviceNotConnectedError: break except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Error 
reading frame in background thread for {self}: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e def _start_read_thread(self) -> None: """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() + self._stop_read_thread() self.stop_event = Event() self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") @@ -498,6 +527,12 @@ class RealSenseCamera(Camera): self.thread = None self.stop_event = None + with self.frame_lock: + self.latest_color_frame = None + self.latest_depth_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + # NOTE(Steven): Missing implementation for depth for now def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ @@ -506,6 +541,7 @@ class RealSenseCamera(Camera): This method retrieves the most recent color frame captured by the background read thread. It does not block waiting for the camera hardware directly, but may wait up to timeout_ms for the background thread to provide a frame. + It is “best effort” under high FPS. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -524,17 +560,16 @@ class RealSenseCamera(Camera): raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() raise TimeoutError( f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." + f"Read thread alive: {self.thread.is_alive()}." 
) with self.frame_lock: - frame = self.latest_frame + frame = self.latest_color_frame self.new_frame_event.clear() if frame is None: @@ -542,6 +577,43 @@ class RealSenseCamera(Camera): return frame + # NOTE(Steven): Missing implementation for depth for now + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent (color) frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_color_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """ Disconnects from the camera, stops the pipeline, and cleans up resources. 
@@ -565,4 +637,10 @@ class RealSenseCamera(Camera): self.rs_pipeline = None self.rs_profile = None + with self.frame_lock: + self.latest_color_frame = None + self.latest_depth_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index 1a4155f4b..a231a582a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -45,6 +45,12 @@ logger = logging.getLogger(__name__) class ZMQCamera(Camera): """ + Manages camera interactions via ZeroMQ for receiving frames from a remote server. + + This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes + incoming JSON messages containing Base64 encoded images. It supports both + synchronous and asynchronous frame reading patterns. + Example usage: ```python from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig @@ -52,7 +58,16 @@ class ZMQCamera(Camera): config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera") camera = ZMQCamera(config) camera.connect() - frame = camera.read() + + # Read 1 frame synchronously (blocking) + color_image = camera.read() + + # Read 1 frame asynchronously (waits for new frame with a timeout) + async_image = camera.async_read() + + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() + camera.disconnect() ``` """ @@ -68,14 +83,17 @@ class ZMQCamera(Camera): self.color_mode = config.color_mode self.timeout_ms = config.timeout_ms + # ZMQ Context and Socket self.context: zmq.Context | None = None self.socket: zmq.Socket | None = None self._connected = False + # Threading resources self.thread: Thread | None = None self.stop_event: Event | None = None self.frame_lock: Lock = Lock() self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: 
Event = Event() def __str__(self) -> str: @@ -83,10 +101,16 @@ class ZMQCamera(Camera): @property def is_connected(self) -> bool: + """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None def connect(self, warmup: bool = True) -> None: - """Connect to ZMQ camera server.""" + """Connect to ZMQ camera server. + + Args: + warmup (bool): If True, waits for the camera to provide at least one + valid frame before returning. Defaults to True. + """ if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} is already connected.") @@ -103,17 +127,28 @@ class ZMQCamera(Camera): self.socket.connect(f"tcp://{self.server_address}:{self.port}") self._connected = True - # Auto-detect resolution + # Auto-detect resolution if not provided if self.width is None or self.height is None: - h, w = self.read().shape[:2] + # Read directly from hardware because the thread isn't running yet + temp_frame = self._read_from_hardware() + h, w = temp_frame.shape[:2] self.height = h self.width = w - logger.info(f"{self} resolution: {w}x{h}") + logger.info(f"{self} resolution detected: {w}x{h}") + self._start_read_thread() logger.info(f"{self} connected.") if warmup: - time.sleep(0.1) + # Ensure we have captured at least one frame via the thread + start_time = time.time() + while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout + self.async_read(timeout_ms=self.config.warmup_s * 1000) + time.sleep(0.1) + + with self.frame_lock: + if self.latest_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") except Exception as e: self._cleanup() @@ -134,12 +169,9 @@ class ZMQCamera(Camera): """ZMQ cameras require manual configuration (server address/port).""" return [] - def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + def _read_from_hardware(self) -> NDArray[Any]: """ - Read a single frame from the ZMQ camera. 
- - Returns: - np.ndarray: Decoded frame (height, width, 3) + Reads a single frame directly from the ZMQ socket. """ if not self.is_connected or self.socket is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -147,6 +179,7 @@ class ZMQCamera(Camera): try: message = self.socket.recv_string() except Exception as e: + # Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import if type(e).__name__ == "Again": raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e raise @@ -176,42 +209,117 @@ class ZMQCamera(Camera): return frame - def _read_loop(self) -> None: - while self.stop_event and not self.stop_event.is_set(): - try: - frame = self.read() - with self.frame_lock: - self.latest_frame = frame - self.new_frame_event.set() - except DeviceNotConnectedError: - break - except TimeoutError: - pass - except Exception as e: - logger.warning(f"Read error: {e}") + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + """ + Reads a single frame synchronously from the camera. - def _start_read_thread(self) -> None: - if self.thread and self.thread.is_alive(): - return - self.stop_event = Event() - self.thread = Thread(target=self._read_loop, daemon=True) - self.thread.start() + This is a blocking call. It waits for the next available frame from the + camera background thread. - def _stop_read_thread(self) -> None: - if self.stop_event: - self.stop_event.set() - if self.thread and self.thread.is_alive(): - self.thread.join(timeout=2.0) - self.thread = None - self.stop_event = None + Returns: + np.ndarray: Decoded frame (height, width, 3) + """ + start_time = time.perf_counter() + + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
+ ) - def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]: - """Read latest frame asynchronously (non-blocking).""" if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - if not self.thread or not self.thread.is_alive(): - self._start_read_thread() + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + self.new_frame_event.clear() + frame = self.async_read(timeout_ms=10000) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return frame + + def _read_loop(self) -> None: + """ + Internal loop run by the background thread for asynchronous reading. + """ + if self.stop_event is None: + raise RuntimeError(f"{self}: stop_event is not initialized.") + + failure_count = 0 + while not self.stop_event.is_set(): + try: + frame = self._read_from_hardware() + capture_time = time.perf_counter() + + with self.frame_lock: + self.latest_frame = frame + self.latest_timestamp = capture_time + self.new_frame_event.set() + failure_count = 0 + + except DeviceNotConnectedError: + break + except (TimeoutError, Exception) as e: + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Read error: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e + + def _start_read_thread(self) -> None: + if self.stop_event is not None: + self.stop_event.set() + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop") + self.thread.start() + time.sleep(0.1) + + def _stop_read_thread(self) -> None: + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and 
self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + + def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: + """ + Reads the latest available frame asynchronously. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms. + + Returns: + np.ndarray: The latest captured frame. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread is not running. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms") @@ -225,11 +333,55 @@ class ZMQCamera(Camera): return frame + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. 
+ """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """Disconnect from ZMQ camera.""" - if not self.is_connected and not self.thread: + if not self.is_connected and self.thread is None: raise DeviceNotConnectedError(f"{self} not connected.") - self._stop_read_thread() + if self.thread is not None: + self._stop_read_thread() + self._cleanup() + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 027ae12b5..4e7732cfc 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -29,6 +29,7 @@ class ZMQCameraConfig(CameraConfig): camera_name: str = "zmq_camera" color_mode: ColorMode = ColorMode.RGB timeout_ms: int = 5000 + warmup_s: int = 1 def __post_init__(self) -> None: if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index eb3df6872..1b30021dd 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -86,8 +86,11 @@ def calibrate(cfg: CalibrateConfig): device = make_teleoperator_from_config(cfg.device) device.connect(calibrate=False) - device.calibrate() - 
device.disconnect() + + try: + device.calibrate() + finally: + device.disconnect() def main(): diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 5717dffb6..c9a559d07 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -110,25 +110,26 @@ def replay(cfg: ReplayConfig): robot.connect() - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(len(episode_frames)): - start_episode_t = time.perf_counter() + try: + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(len(episode_frames)): + start_episode_t = time.perf_counter() - action_array = actions[idx][ACTION] - action = {} - for i, name in enumerate(dataset.features[ACTION]["names"]): - action[name] = action_array[i] + action_array = actions[idx][ACTION] + action = {} + for i, name in enumerate(dataset.features[ACTION]["names"]): + action[name] = action_array[i] - robot_obs = robot.get_observation() + robot_obs = robot.get_observation() - processed_action = robot_action_processor((action, robot_obs)) + processed_action = robot_action_processor((action, robot_obs)) - _ = robot.send_action(processed_action) + _ = robot.send_action(processed_action) - dt_s = time.perf_counter() - start_episode_t - precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) - - robot.disconnect() + dt_s = time.perf_counter() - start_episode_t + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) + finally: + robot.disconnect() def main(): diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py index 3cf3793b6..feb700631 100644 --- a/tests/cameras/test_opencv.py +++ b/tests/cameras/test_opencv.py @@ -20,7 +20,9 @@ # ``` from pathlib import Path +from unittest.mock import patch +import cv2 import numpy as np import pytest @@ -28,6 +30,50 @@ from lerobot.cameras.configs import Cv2Rotation from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig from lerobot.utils.errors import 
DeviceAlreadyConnectedError, DeviceNotConnectedError +RealVideoCapture = cv2.VideoCapture + + +class MockLoopingVideoCapture: + """ + Wraps the real OpenCV VideoCapture. + Motivation: cv2.VideoCapture(file.png) is only valid for one read. + Strategy: Read the file once & return the cached frame for subsequent reads. + Consequence: No recurrent I/O operations, but we keep the test artifacts simple. + """ + + def __init__(self, *args, **kwargs): + args_clean = [str(a) if isinstance(a, Path) else a for a in args] + self._real_vc = RealVideoCapture(*args_clean, **kwargs) + self._cached_frame = None + + def read(self): + ret, frame = self._real_vc.read() + + if ret: + self._cached_frame = frame + return ret, frame + + if not ret and self._cached_frame is not None: + return True, self._cached_frame.copy() + + return ret, frame + + def __getattr__(self, name): + return getattr(self._real_vc, name) + + +@pytest.fixture(autouse=True) +def patch_opencv_videocapture(): + """ + Automatically patches cv2.VideoCapture for all tests. + """ + module_path = OpenCVCamera.__module__ + target = f"{module_path}.cv2.VideoCapture" + + with patch(target, new=MockLoopingVideoCapture): + yield + + # NOTE(Steven): more tests + assertions? 
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png" @@ -43,25 +89,22 @@ def test_abc_implementation(): def test_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - camera.connect(warmup=False) - - assert camera.is_connected + with OpenCVCamera(config) as camera: + assert camera.is_connected def test_connect_already_connected(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - with pytest.raises(DeviceAlreadyConnectedError): - camera.connect(warmup=False) + with OpenCVCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError): + camera.connect() def test_connect_invalid_camera_path(): config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png") + camera = OpenCVCamera(config) with pytest.raises(ConnectionError): @@ -74,27 +117,25 @@ def test_invalid_width_connect(): width=99999, # Invalid width to trigger error height=480, ) - camera = OpenCVCamera(config) + camera = OpenCVCamera(config) with pytest.raises(RuntimeError): camera.connect(warmup=False) @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) def test_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0) - img = camera.read() - - assert isinstance(img, np.ndarray) + with OpenCVCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) def test_read_before_connect(): config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) + camera = OpenCVCamera(config) with 
pytest.raises(DeviceNotConnectedError): _ = camera.read() @@ -119,32 +160,22 @@ def test_disconnect_before_connect(): @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) def test_async_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0) - try: + with OpenCVCamera(config) as camera: img = camera.async_read() assert camera.thread is not None assert camera.thread.is_alive() assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends def test_async_read_timeout(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() + with OpenCVCamera(config) as camera, pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) # consumes any available frame by then + camera.async_read(timeout_ms=0) # request immediately another one def test_async_read_before_connect(): @@ -155,6 +186,50 @@ def test_async_read_before_connect(): _ = camera.async_read() +def test_read_latest(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # ensure at least one fresh frame is captured + frame = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == frame.shape + + +def test_read_latest_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + + camera = OpenCVCamera(config) + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def 
test_read_latest_high_frequency(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # prime to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_too_old(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + def test_fourcc_configuration(): """Test FourCC configuration validation and application.""" @@ -181,18 +256,15 @@ def test_fourcc_configuration(): def test_fourcc_with_camera(): """Test FourCC functionality with actual camera connection.""" - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG") - camera = OpenCVCamera(config) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG", warmup_s=0) # Connect should work with MJPG specified - camera.connect(warmup=False) - assert camera.is_connected + with OpenCVCamera(config) as camera: + assert camera.is_connected - # Read should work normally - img = camera.read() - assert isinstance(img, np.ndarray) - - camera.disconnect() + # Read should work normally + img = camera.read() + assert isinstance(img, np.ndarray) @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) @@ -211,18 +283,16 @@ def test_rotation(rotation, index_or_path): dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png) original_width, original_height = map(int, dimensions.split("x")) - config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, 
rotation=rotation, warmup_s=0) + with OpenCVCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == original_height - assert camera.height == original_width - assert img.shape[:2] == (original_width, original_height) - else: - assert camera.width == original_width - assert camera.height == original_height - assert img.shape[:2] == (original_height, original_width) + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == original_height + assert camera.height == original_width + assert img.shape[:2] == (original_width, original_height) + else: + assert camera.width == original_width + assert camera.height == original_height + assert img.shape[:2] == (original_height, original_width) diff --git a/tests/cameras/test_reachy2_camera.py b/tests/cameras/test_reachy2_camera.py index 14774bf38..2aebfdf0a 100644 --- a/tests/cameras/test_reachy2_camera.py +++ b/tests/cameras/test_reachy2_camera.py @@ -150,6 +150,44 @@ def test_async_read_before_connect(camera): _ = camera.async_read() +def test_read_latest(camera): + camera.connect() + + frame = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == frame.shape + + +def test_read_latest_before_connect(camera): + # camera fixture yields an unconnected camera instance + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_high_frequency(camera): + camera.connect() + + # prime to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_too_old(camera): + camera.connect() + + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): 
+ _ = camera.read_latest(max_age_ms=0) # immediately too old + + def test_wrong_camera_name(): with pytest.raises(ValueError): _ = Reachy2CameraConfig(name="wrong-name", image_type="left") diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py index fb9912257..1deb73f05 100644 --- a/tests/cameras/test_realsense.py +++ b/tests/cameras/test_realsense.py @@ -62,19 +62,15 @@ def test_abc_implementation(): def test_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) + config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0) - camera.connect(warmup=False) - assert camera.is_connected + with RealSenseCamera(config) as camera: + assert camera.is_connected def test_connect_already_connected(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - with pytest.raises(DeviceAlreadyConnectedError): + config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0) + with RealSenseCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError): camera.connect(warmup=False) @@ -96,12 +92,10 @@ def test_invalid_width_connect(): def test_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - img = camera.read() - assert isinstance(img, np.ndarray) + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) # TODO(Steven): Fix this test for the latest version of pyrealsense2. 
@@ -142,32 +136,21 @@ def test_disconnect_before_connect(): def test_async_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) - try: + with RealSenseCamera(config) as camera: img = camera.async_read() assert camera.thread is not None assert camera.thread.is_alive() assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends def test_async_read_timeout(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera, pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) # consumes any available frame by then + camera.async_read(timeout_ms=0) # request immediately another one def test_async_read_before_connect(): @@ -178,6 +161,47 @@ def test_async_read_before_connect(): _ = camera.async_read() +def test_read_latest(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == img.shape + + +def test_read_latest_high_frequency(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + # prime with one read to ensure frames are available + ref = camera.read() + + for _ in 
range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_before_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_too_old(): + config = RealSenseCameraConfig(serial_number_or_name="042") + + with RealSenseCamera(config) as camera: + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + @pytest.mark.parametrize( "rotation", [ @@ -189,18 +213,16 @@ def test_async_read_before_connect(): ids=["no_rot", "rot90", "rot180", "rot270"], ) def test_rotation(rotation): - config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation) - camera = RealSenseCamera(config) - camera.connect(warmup=False) + config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == 480 - assert camera.height == 640 - assert img.shape[:2] == (640, 480) - else: - assert camera.width == 640 - assert camera.height == 480 - assert img.shape[:2] == (480, 640) + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == 480 + assert camera.height == 640 + assert img.shape[:2] == (640, 480) + else: + assert camera.width == 640 + assert camera.height == 480 + assert img.shape[:2] == (480, 640) From 04cbf669cf0565950f8ba66e8a03a66bd8f20d7a Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 30 Jan 2026 12:23:22 +0100 Subject: [PATCH 10/43] fix(sac): make temperature a property to fix checkpoint resume bug (#2877) * fix(sac): make 
temperature a property to fix checkpoint resume bug Temperature was stored as a plain float and not restored after loading a checkpoint, causing incorrect loss computations until update_temperature() was called. Changed to a property that always computes from log_alpha, ensuring correct behavior after checkpoint loading. * simplify docstrings --- src/lerobot/policies/sac/modeling_sac.py | 11 ++++++----- src/lerobot/rl/learner.py | 3 --- tests/policies/test_sac_policy.py | 3 ++- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index c7c6798ed..d5dd71a48 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -239,8 +239,10 @@ class SACPolicy( + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - def update_temperature(self): - self.temperature = self.log_alpha.exp().item() + @property + def temperature(self) -> float: + """Return the current temperature value, always in sync with log_alpha.""" + return self.log_alpha.exp().item() def compute_loss_critic( self, @@ -457,11 +459,10 @@ class SACPolicy( dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) self.target_entropy = -np.prod(dim) / 2 - def _init_temperature(self): - """Set up temperature parameter and initial log_alpha.""" + def _init_temperature(self) -> None: + """Set up temperature parameter (log_alpha).""" temp_init = self.config.temperature_init self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) - self.temperature = self.log_alpha.exp().item() class SACObservationEncoder(nn.Module): diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index abc5c9504..ee09ac9ac 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -545,9 +545,6 @@ def add_actor_information_and_train( training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature"] = 
policy.temperature - # Update temperature - policy.update_temperature() - # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 8576883bd..6fad2979e 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy(): def test_sac_policy_update_temperature(): + """Test that temperature property is always in sync with log_alpha.""" config = create_default_config(continuous_action_dim=10, state_dim=10) policy = SACPolicy(config=config) assert policy.temperature == pytest.approx(1.0) policy.log_alpha.data = torch.tensor([math.log(0.1)]) - policy.update_temperature() + # Temperature property automatically reflects log_alpha changes assert policy.temperature == pytest.approx(0.1) From ec04b7ce3aca23491e42232c1ae723bb4b981993 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 30 Jan 2026 13:19:42 +0100 Subject: [PATCH 11/43] Feat(dataset_tools.py) Add modify tasks tool (#2875) * feat(datasets): add modify_tasks function for in-place task editing Add a new utility function to modify tasks in LeRobotDataset in-place. This allows users to: - Set a single task for all episodes - Set specific tasks for individual episodes - Combine a default task with per-episode overrides * feat(edit-dataset): add CLI support for modify_tasks operation Integrate the modify_tasks function into lerobot_edit_dataset CLI. Users can now modify dataset tasks via command line: Supports setting a default task, per-episode tasks, or both combined. 
* test(datasets): add tests for modify_tasks function Add comprehensive test coverage for the modify_tasks utility: - Single task for all episodes - Episode-specific task assignment - Default task with per-episode overrides - Error handling for missing/invalid arguments - Verification of task_index correctness - In-place modification behavior - Metadata preservation * respond to copilot review --- src/lerobot/datasets/dataset_tools.py | 126 +++++++++++++++ src/lerobot/scripts/lerobot_edit_dataset.py | 82 +++++++++- tests/datasets/test_dataset_tools.py | 169 ++++++++++++++++++++ 3 files changed, 374 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index e2928e2a6..123d455c6 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024 BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB +def modify_tasks( + dataset: LeRobotDataset, + new_task: str | None = None, + episode_tasks: dict[int, str] | None = None, +) -> LeRobotDataset: + """Modify tasks in a LeRobotDataset. + + This function allows you to either: + 1. Set a single task for the entire dataset (using `new_task`) + 2. Set specific tasks for specific episodes (using `episode_tasks`) + + You can combine both: `new_task` sets the default, and `episode_tasks` overrides + specific episodes. + + The dataset is modified in-place, updating only the task-related files: + - meta/tasks.parquet + - data/**/*.parquet (task_index column) + - meta/episodes/**/*.parquet (tasks column) + - meta/info.json (total_tasks) + + Args: + dataset: The source LeRobotDataset to modify. + new_task: A single task string to apply to all episodes. If None and episode_tasks + is also None, raises an error. + episode_tasks: Optional dict mapping episode indices to their task strings. + Overrides `new_task` for specific episodes. 
+ + + Examples: + Set a single task for all episodes: + dataset = modify_tasks(dataset, new_task="Pick up the cube") + + Set different tasks for specific episodes: + dataset = modify_tasks( + dataset, + episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"} + ) + + Set a default task with overrides: + dataset = modify_tasks( + dataset, + new_task="Default task", + episode_tasks={5: "Special task for episode 5"} + ) + """ + if new_task is None and episode_tasks is None: + raise ValueError("Must specify at least one of new_task or episode_tasks") + + if episode_tasks is not None: + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_tasks.keys()) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + # Ensure episodes metadata is loaded + if dataset.meta.episodes is None: + dataset.meta.episodes = load_episodes(dataset.root) + + # Build the mapping from episode index to task string + episode_to_task: dict[int, str] = {} + for ep_idx in range(dataset.meta.total_episodes): + if episode_tasks and ep_idx in episode_tasks: + episode_to_task[ep_idx] = episode_tasks[ep_idx] + elif new_task is not None: + episode_to_task[ep_idx] = new_task + else: + # Keep original task if not overridden and no default provided + original_tasks = dataset.meta.episodes[ep_idx]["tasks"] + if not original_tasks: + raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided") + episode_to_task[ep_idx] = original_tasks[0] + + # Collect all unique tasks and create new task mapping + unique_tasks = sorted(set(episode_to_task.values())) + new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks) + task_to_index = {task: idx for idx, task in enumerate(unique_tasks)} + + logging.info(f"Modifying tasks in {dataset.repo_id}") + logging.info(f"New tasks: {unique_tasks}") + + root = dataset.root + + # Update data files - modify task_index column + logging.info("Updating data 
files...") + data_dir = root / DATA_DIR + + for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"): + df = pd.read_parquet(parquet_path) + + # Build a mapping from episode_index to new task_index for rows in this file + episode_indices_in_file = df["episode_index"].unique() + ep_to_new_task_idx = { + ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file + } + + # Update task_index column + df["task_index"] = df["episode_index"].map(ep_to_new_task_idx) + df.to_parquet(parquet_path, index=False) + + # Update episodes metadata - modify tasks column + logging.info("Updating episodes metadata...") + episodes_dir = root / "meta" / "episodes" + + for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"): + df = pd.read_parquet(parquet_path) + + # Update tasks column + df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]]) + df.to_parquet(parquet_path, index=False) + + # Write new tasks.parquet + write_tasks(new_task_df, root) + + # Update info.json + dataset.meta.info["total_tasks"] = len(unique_tasks) + write_info(dataset.meta.info, root) + + # Reload metadata to reflect changes + dataset.meta.tasks = new_task_df + dataset.meta.episodes = load_episodes(root) + + logging.info(f"Tasks: {unique_tasks}") + + return dataset + + def convert_image_to_video_dataset( dataset: LeRobotDataset, output_dir: Path, diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 4ba6ce44f..2ca9c520d 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -18,7 +18,7 @@ Edit LeRobot datasets using various transformation tools. This script allows you to delete episodes, split datasets, merge datasets, -remove features, and convert image datasets to video format. +remove features, modify tasks, and convert image datasets to video format. 
When new_repo_id is specified, creates a new dataset. Usage Examples: @@ -66,6 +66,25 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" +Modify tasks - set a single task for all episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Pick up the cube and place it" + +Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' + +Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Default task" \ + --operation.episode_tasks '{"5": "Special task for episode 5"}' + Convert image dataset to video format and save locally: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ @@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, delete_episodes, merge_datasets, + modify_tasks, remove_feature, split_dataset, ) @@ -132,6 +152,13 @@ class RemoveFeatureConfig: feature_names: list[str] | None = None +@dataclass +class ModifyTasksConfig: + type: str = "modify_tasks" + new_task: str | None = None + episode_tasks: dict[str, str] | None = None + + @dataclass class ConvertImageToVideoConfig: type: str = "convert_image_to_video" @@ -151,7 +178,12 @@ class ConvertImageToVideoConfig: class EditDatasetConfig: repo_id: str operation: ( - DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig + DeleteEpisodesConfig + | SplitConfig + | MergeConfig + | RemoveFeatureConfig + | ModifyTasksConfig + | 
ConvertImageToVideoConfig ) root: str | None = None new_repo_id: str | None = None @@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() +def handle_modify_tasks(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, ModifyTasksConfig): + raise ValueError("Operation config must be ModifyTasksConfig") + + new_task = cfg.operation.new_task + episode_tasks_raw = cfg.operation.episode_tasks + + if new_task is None and episode_tasks_raw is None: + raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation") + + # Warn about in-place modification behavior + if cfg.new_repo_id is not None: + logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.") + + # Convert episode_tasks keys from string to int if needed (CLI passes strings) + episode_tasks: dict[int, str] | None = None + if episode_tasks_raw is not None: + episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()} + + logging.info(f"Modifying tasks in {cfg.repo_id}") + if new_task: + logging.info(f" Default task: '{new_task}'") + if episode_tasks: + logging.info(f" Episode-specific tasks: {episode_tasks}") + + modified_dataset = modify_tasks( + dataset, + new_task=new_task, + episode_tasks=episode_tasks, + ) + + logging.info(f"Dataset modified at {dataset.root}") + logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + modified_dataset.push_to_hub() + + def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: # Note: Parser may create any config type with the right fields, so we access fields directly # instead of checking isinstance() @@ -371,12 +445,14 @@ def 
edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) + elif operation_type == "modify_tasks": + handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video" + f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" ) diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 35a369de9..1de199630 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import ( delete_episodes, merge_datasets, modify_features, + modify_tasks, remove_feature, split_dataset, ) @@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert "reward" in modified_dataset.meta.features +def test_modify_tasks_single_task_for_all(sample_dataset): + """Test setting a single task for all episodes.""" + new_task = "Pick up the cube and place it" + + modified_dataset = modify_tasks(sample_dataset, new_task=new_task) + + # Verify all episodes have the new task + assert len(modified_dataset.meta.tasks) == 1 + assert new_task in modified_dataset.meta.tasks.index + + # Verify task_index is 0 for all frames (only one task) + for i in range(len(modified_dataset)): + item = modified_dataset[i] + assert item["task_index"].item() == 0 + assert item["task"] == new_task + + +def test_modify_tasks_episode_specific(sample_dataset): + """Test setting different tasks for specific episodes.""" + episode_tasks = { + 0: "Task A", + 1: "Task B", + 2: "Task A", + 3: "Task C", + 4: "Task B", + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify correct number of unique 
tasks + unique_tasks = set(episode_tasks.values()) + assert len(modified_dataset.meta.tasks) == len(unique_tasks) + + # Verify each episode has the correct task + for ep_idx, expected_task in episode_tasks.items(): + ep_data = modified_dataset.meta.episodes[ep_idx] + assert ep_data["tasks"][0] == expected_task + + +def test_modify_tasks_default_with_overrides(sample_dataset): + """Test setting a default task with specific overrides.""" + default_task = "Default task" + override_task = "Special task" + episode_tasks = {2: override_task, 4: override_task} + + modified_dataset = modify_tasks( + sample_dataset, + new_task=default_task, + episode_tasks=episode_tasks, + ) + + # Verify correct number of unique tasks + assert len(modified_dataset.meta.tasks) == 2 + assert default_task in modified_dataset.meta.tasks.index + assert override_task in modified_dataset.meta.tasks.index + + # Verify episodes have correct tasks + for ep_idx in range(5): + ep_data = modified_dataset.meta.episodes[ep_idx] + if ep_idx in episode_tasks: + assert ep_data["tasks"][0] == override_task + else: + assert ep_data["tasks"][0] == default_task + + +def test_modify_tasks_no_task_specified(sample_dataset): + """Test error when no task is specified.""" + with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"): + modify_tasks(sample_dataset) + + +def test_modify_tasks_invalid_episode_indices(sample_dataset): + """Test error with invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"}) + + +def test_modify_tasks_updates_info_json(sample_dataset): + """Test that total_tasks is updated in info.json.""" + episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify total_tasks is updated + assert modified_dataset.meta.total_tasks == 3 + + +def 
test_modify_tasks_preserves_other_metadata(sample_dataset): + """Test that modifying tasks preserves other metadata.""" + original_frames = sample_dataset.meta.total_frames + original_episodes = sample_dataset.meta.total_episodes + original_fps = sample_dataset.meta.fps + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify other metadata is preserved + assert modified_dataset.meta.total_frames == original_frames + assert modified_dataset.meta.total_episodes == original_episodes + assert modified_dataset.meta.fps == original_fps + + +def test_modify_tasks_task_index_correct(sample_dataset): + """Test that task_index values are correct in data files.""" + # Create tasks that will have predictable indices (sorted alphabetically) + episode_tasks = { + 0: "Alpha task", # Will be index 0 + 1: "Beta task", # Will be index 1 + 2: "Alpha task", # Will be index 0 + 3: "Gamma task", # Will be index 2 + 4: "Beta task", # Will be index 1 + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify task indices are correct + task_to_expected_idx = { + "Alpha task": 0, + "Beta task": 1, + "Gamma task": 2, + } + + for i in range(len(modified_dataset)): + item = modified_dataset[i] + ep_idx = item["episode_index"].item() + expected_task = episode_tasks[ep_idx] + expected_idx = task_to_expected_idx[expected_task] + assert item["task_index"].item() == expected_idx + assert item["task"] == expected_task + + +def test_modify_tasks_in_place(sample_dataset): + """Test that modify_tasks modifies the dataset in-place.""" + original_root = sample_dataset.root + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify same instance is returned and root is unchanged + assert modified_dataset is sample_dataset + assert modified_dataset.root == original_root + + +def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): + """Test that original tasks are kept when using episode_tasks 
without new_task.""" + from lerobot.datasets.utils import load_episodes + + # Ensure episodes metadata is loaded + if sample_dataset.meta.episodes is None: + sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root) + + # Get original tasks for episodes not being overridden + original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0] + original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0] + + # Only override episodes 2, 3, 4 + episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify original tasks are kept for episodes 0 and 1 + assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0 + assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1 + + # Verify new tasks for overridden episodes + assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A" + assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B" + assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A" + + def test_convert_image_to_video_dataset(tmp_path): """Test converting lerobot/pusht_image dataset to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset From 55c0471db9e440e99e801e2e67d645ecd7fdb9d5 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 30 Jan 2026 16:57:56 +0100 Subject: [PATCH 12/43] docs(cameras): revising and improving docs on cameras (#2878) * docs(cameras): revising and improving docs on cameras * resolving copilot comments --- docs/source/_toctree.yml | 6 +- docs/source/cameras.mdx | 176 +++++++++++++++++++++------------------ 2 files changed, 99 insertions(+), 83 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index eb97117af..98417f134 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,8 +7,6 @@ - sections: - local: il_robots title: Imitation Learning for Robots - - local: cameras - title: Cameras 
- local: bring_your_own_policies title: Bring Your Own Policies - local: integrate_hardware @@ -108,6 +106,10 @@ - local: phone_teleop title: Phone title: "Teleoperators" +- sections: + - local: cameras + title: Cameras + title: "Sensors" - sections: - local: torch_accelerators title: PyTorch accelerators diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 5c35be0ba..8af0f5ae5 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -1,12 +1,22 @@ # Cameras -LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). +LeRobot offers multiple options for video capture: -### Finding your camera +| Class | Supported Cameras | +| ----------------- | ----------------------------------- | +| `OpenCVCamera` | Phone, built-in laptop, USB webcams | +| `ZMQCamera` | Network-connected cameras | +| `RealSenseCamera` | Intel RealSense (with depth) | +| `Reachy2Camera` | Reachy 2 robot cameras | -To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system. +> [!TIP] +> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). -To find the camera indices of the cameras plugged into your system, run the following script: +### Find your camera + +Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices. + +`OpenCVCamera` and `RealSenseCamera` support auto-discovery. 
Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system. ```bash lerobot-find-cameras opencv # or realsense for Intel Realsense cameras @@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras The output will look something like this if you have two cameras connected: -``` +```bash --- Detected Cameras --- Camera #0: Name: OpenCV Camera @ 0 @@ -33,13 +43,37 @@ Camera #0: > [!WARNING] > When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable. -## Use Cameras +`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings. -Below are two examples, demonstrating how to work with the API. +## Use cameras -- **Asynchronous frame capture** using an OpenCV-based camera +### Frame access modes + +All camera classes implement three access modes for capturing frames: + +| Method | Behavior | Blocks? | Best For | +| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- | +| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture | +| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. 
| With a timeout | Control loops synchronized to camera FPS | +| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring | + +### Usage examples + +The following examples show how to use the camera API to configure and capture frames from different camera types. + +- **Blocking and non-blocking frame capture** using an OpenCV-based camera - **Color and depth capture** using an Intel RealSense camera +> [!WARNING] +> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup: +> +> ```python +> with OpenCVCamera(config) as camera: +> ... +> ``` +> +> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter. + @@ -60,16 +94,30 @@ config = OpenCVCameraConfig( ) # Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default). -camera = OpenCVCamera(config) -camera.connect() +with OpenCVCamera(config) as camera: + + # Read a frame synchronously — blocks until hardware delivers a new frame + frame = camera.read() + print(f"read() call returned frame with shape:", frame.shape) + + # Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one + try: + for i in range(10): + frame = camera.async_read(timeout_ms=200) + print(f"async_read call returned frame {i} with shape:", frame.shape) + except TimeoutError as e: + print(f"No frame received within timeout: {e}") + + # Instantly return a frame - returns the most recent frame captured by the camera + try: + initial_frame = camera.read_latest(max_age_ms=1000) + for i in range(10): + frame = camera.read_latest(max_age_ms=1000) + print(f"read_latest call returned frame {i} with shape:", frame.shape) + print(f"Was a new frame received by the camera? 
{not (initial_frame == frame).any()}") + except TimeoutError as e: + print(f"Frame too old: {e}") -# Read frames asynchronously in a loop via `async_read(timeout_ms)` -try: - for i in range(10): - frame = camera.async_read(timeout_ms=200) - print(f"Async frame {i} shape:", frame.shape) -finally: - camera.disconnect() ``` @@ -111,10 +159,10 @@ finally: -## Use your phone +## Use your phone's camera - + To use your iPhone as a camera on macOS, enable the Continuity Camera feature: @@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature: For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac). -Your iPhone should be detected automatically when running the camera setup script in the next section. - - + -If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera +If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera. -1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: +1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with: - -```python +```bash sudo apt install v4l2loopback-dkms v4l-utils ``` - -2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android. -3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): +2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android. +3. _Download and install [OBS Studio](https://obsproject.com)_. +4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_. +5. _Start OBS Studio_. 
- -```python -flatpak install flathub com.obsproject.Studio -``` - - -4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with: - - -```python -flatpak install flathub com.obsproject.Studio.Plugin.DroidCam -``` - - -5. _Start OBS Studio_. Launch with: - - -```python -flatpak run com.obsproject.Studio -``` - - -6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. -7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. +6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks. +7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it. 8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). -9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices: +9. _Verify the virtual camera setup and resolution_. + - **Linux**: Use `v4l2-ctl` to list devices and check resolution: + ```bash + v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path + v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path + ``` + You should see `VirtualCam` listed and resolution `640x480`. + - **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input. + - **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify. - -```python -v4l2-ctl --list-devices -``` - +
+Troubleshooting -You should see an entry like: +> The virtual camera resolution is incorrect. -``` -VirtualCam (platform:v4l2loopback-000): -/dev/video1 -``` +Delete the virtual camera source and recreate it. The resolution cannot be changed after creation. -10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. +> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080. - -```python -v4l2-ctl -d /dev/video1 --get-fmt-video -``` - +This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`. -You should see an entry like: - -``` ->>> Format Video Capture: ->>> Width/Height : 640/480 ->>> Pixel Format : 'YUYV' (YUYV 4:2:2) -``` - -Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed. - -If everything is set up correctly, you can proceed with the rest of the tutorial. +
+ +If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`. From 5c6182176f31996fb5d0c51f88a1bc59457ba7a6 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 30 Jan 2026 16:58:13 +0100 Subject: [PATCH 13/43] fix(find zmq): adding a clearer not implemented warning for the ZMQ find_cameras method (#2879) Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com> --- src/lerobot/cameras/zmq/camera_zmq.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index a231a582a..f29e16a28 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -166,8 +166,10 @@ class ZMQCamera(Camera): @staticmethod def find_cameras() -> list[dict[str, Any]]: - """ZMQ cameras require manual configuration (server address/port).""" - return [] + """ + Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port). 
+ """ + raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.") def _read_from_hardware(self) -> NDArray[Any]: """ From b18cef2e260a80db6cbe2327140950964c797b46 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 30 Jan 2026 10:29:37 -0800 Subject: [PATCH 14/43] feat(dataset): add subtask support (#2860) * add subtask * remove folder * add docs * update doc * add testing * update test * update constant naming + doc * more docs --- docs/source/_toctree.yml | 2 + docs/source/dataset_subtask.mdx | 278 +++++++++++ src/lerobot/datasets/lerobot_dataset.py | 9 + src/lerobot/datasets/utils.py | 9 + src/lerobot/processor/converters.py | 3 +- src/lerobot/processor/tokenizer_processor.py | 46 ++ src/lerobot/utils/constants.py | 3 + tests/datasets/test_subtask_dataset.py | 190 ++++++++ tests/processor/test_tokenizer_processor.py | 465 ++++++++++++++++++- 9 files changed, 1003 insertions(+), 2 deletions(-) create mode 100644 docs/source/dataset_subtask.mdx create mode 100644 tests/datasets/test_subtask_dataset.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 98417f134..d61aac9c1 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -27,6 +27,8 @@ title: Porting Large Datasets - local: using_dataset_tools title: Using the Dataset Tools + - local: dataset_subtask + title: Using Subtasks in the Dataset title: "Datasets" - sections: - local: act diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx new file mode 100644 index 000000000..beb5d80bd --- /dev/null +++ b/docs/source/dataset_subtask.mdx @@ -0,0 +1,278 @@ +# Using Subtasks in LeRobot Datasets + +Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. 
Subtasks are particularly useful for: + +- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time +- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models) +- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps + +LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks. + +## What are Subtasks? + +While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps: + +1. "Approach the apple" +2. "Grasp the apple" +3. "Lift the apple" +4. "Move to basket" +5. "Release the apple" + +Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages. + +An overview of subtask annotation showing how frames are labeled with intermediate subtask stages + +

+ Figure: Overview of subtask annotation. +

+ +**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022. + +## Dataset Structure + +Subtask information is stored in the dataset metadata: + +``` +my-dataset/ +├── data/ +│ └── ... +├── meta/ +│ ├── info.json +│ ├── stats.json +│ ├── tasks.parquet +│ ├── subtasks.parquet # Subtask index → subtask string mapping +│ └── episodes/ +│ └── ... +└── videos/ + └── ... +``` + +### Subtasks Parquet File + +The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions: + +| subtask_index | subtask (index column) | +| ------------- | ---------------------- | +| 0 | "Approach the apple" | +| 1 | "Grasp the apple" | +| 2 | "Lift the apple" | +| ... | ... | + +### Frame-Level Annotations + +Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file: + +```python +# Example frame data in the parquet file +{ + "index": 42, + "timestamp": 1.4, + "episode_index": 0, + "task_index": 0, + "subtask_index": 2, # References "Lift the apple" + "observation.state": [...], + "action": [...], +} +``` + +## Annotating Datasets with Subtasks + +We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks: + +**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)** + +After completing your annotation: + +1. Click "Push to Hub" to upload your annotated dataset +2. 
You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate) + +## Loading Datasets with Subtasks + +When you load a dataset with subtask annotations, the subtask information is automatically available: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Load a dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Access a sample +sample = dataset[100] + +# The sample includes both task and subtask information +print(sample["task"]) # "Collect the fruit" +print(sample["subtask"]) # "Grasp the apple" +print(sample["task_index"]) # tensor(0) +print(sample["subtask_index"]) # tensor(2) +``` + +### Checking for Subtask Support + +You can check if a dataset has subtask annotations: + +```python +# Check if subtasks are available +has_subtasks = ( + "subtask_index" in dataset.features + and dataset.meta.subtasks is not None +) + +if has_subtasks: + print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks") + print("Subtasks:", list(dataset.meta.subtasks.index)) +``` + +## Using Subtasks for Training + +### With the Tokenizer Processor + +The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models: + +```python +from lerobot.processor.tokenizer_processor import TokenizerProcessor +from lerobot.processor.pipeline import ProcessorPipeline + +# Create a tokenizer processor +tokenizer_processor = TokenizerProcessor( + tokenizer_name_or_path="google/paligemma-3b-pt-224", + padding="max_length", + max_length=64, +) + +# The processor will automatically tokenize subtasks if present in the batch +# and add them to the observation under: +# - "observation.subtask.tokens" +# - "observation.subtask.attention_mask" +``` + +When subtasks are available in the batch, the tokenizer processor adds: + +- `observation.subtask.tokens`: Tokenized subtask 
text +- `observation.subtask.attention_mask`: Attention mask for the subtask tokens + +### DataLoader with Subtasks + +```python +import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=16, + shuffle=True, +) + +for batch in dataloader: + # Access subtask information in the batch + subtasks = batch["subtask"] # List of subtask strings + subtask_indices = batch["subtask_index"] # Tensor of subtask indices + + # Use for training hierarchical policies or reward models + print(f"Batch subtasks: {set(subtasks)}") +``` + +## Example Datasets with Subtask Annotations + +Try loading a dataset with subtask annotations: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Example dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Explore the subtasks +print("Available subtasks:") +for subtask_name in dataset.meta.subtasks.index: + print(f" - {subtask_name}") + +# Get subtask distribution +subtask_counts = {} +for i in range(len(dataset)): + sample = dataset[i] + subtask = sample["subtask"] + subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1 + +print("\nSubtask distribution:") +for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]): + print(f" {subtask}: {count} frames") +``` + +## Use Cases + +### 1. Hierarchical Policy Training + +Train policies that predict both actions and current subtask: + +```python +class HierarchicalPolicy(nn.Module): + def __init__(self, num_subtasks): + super().__init__() + self.action_head = nn.Linear(hidden_dim, action_dim) + self.subtask_head = nn.Linear(hidden_dim, num_subtasks) + + def forward(self, observations): + features = self.encoder(observations) + actions = self.action_head(features) + subtask_logits = self.subtask_head(features) + return actions, subtask_logits +``` + +### 2. 
Stage-Aware Reward Modeling (SARM) + +Build reward models that understand task progression: + +```python +# SARM predicts: +# - Stage: Which subtask is being executed (discrete) +# - Progress: How far along the subtask (continuous 0-1) + +class SARMRewardModel(nn.Module): + def forward(self, observations): + features = self.encoder(observations) + stage_logits = self.stage_classifier(features) + progress = self.progress_regressor(features) + return stage_logits, progress +``` + +### 3. Progress Visualization + +Monitor robot execution by tracking subtask progression: + +```python +def visualize_execution(model, observations): + for t, obs in enumerate(observations): + action, subtask_logits = model(obs) + predicted_subtask = subtask_names[subtask_logits.argmax()] + print(f"t={t}: Executing '{predicted_subtask}'") +``` + +## API Reference + +### LeRobotDataset Properties + +| Property | Type | Description | +| --------------------------- | ---------------------- | ------------------------------------------ | +| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices | +| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present | + +### Sample Keys + +When subtasks are available, each sample includes: + +| Key | Type | Description | +| --------------- | -------------- | ------------------------------------ | +| `subtask_index` | `torch.Tensor` | Integer index of the current subtask | +| `subtask` | `str` | Natural language subtask description | + +## Related Resources + +- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation +- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool +- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6798e7fd7..36bffa190 100644 --- 
a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -57,6 +57,7 @@ from lerobot.datasets.utils import ( load_info, load_nested_dataset, load_stats, + load_subtasks, load_tasks, update_chunk_file_indices, validate_episode_buffer, @@ -162,6 +163,7 @@ class LeRobotDatasetMetadata: self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) @@ -518,6 +520,7 @@ class LeRobotDatasetMetadata: _validate_feature_names(features) obj.tasks = None + obj.subtasks = None obj.episodes = None obj.stats = None obj.info = create_empty_dataset_info( @@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset): # Add task as a string task_idx = item["task_index"].item() item["task"] = self.meta.tasks.iloc[task_idx].name + + # add subtask information if available + if "subtask_index" in self.features and self.meta.subtasks is not None: + subtask_idx = item["subtask_index"].item() + item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name + return item def __repr__(self): diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ed678af6e..321ecedd5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -60,6 +60,7 @@ VIDEO_DIR = "videos" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" @@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame: return tasks +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from 
subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + def write_episodes(episodes: Dataset, local_dir: Path) -> None: """Write episode metadata to a parquet file in the LeRobot v3.0 format. This function writes episode-level metadata to a single parquet file. diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 4f9485fee..18c7b0220 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} + subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} + return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key} def create_transition( diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 5cd1bebb0..df559555a 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -34,6 +34,8 @@ from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, OBS_LANGUAGE_TOKENS, ) from lerobot.utils.import_utils import _transformers_available @@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_subtask(self, transition: EnvTransition) 
-> list[str] | None: + """ + Extracts the subtask from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of subtask strings, or None if the subtask key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + subtask = complementary_data.get("subtask") + if subtask is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(subtask, str): + return [subtask] + elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): + return subtask + + return None + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and adds it to the observation dictionary. @@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + # Tokenize subtask if available + subtask = self.get_subtask(self.transition) + if subtask is not None: + tokenized_subtask = self._tokenize_text(subtask) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_subtask = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_subtask.items() + } + + # Add tokenized subtask to the observation + new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"] + new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to( + dtype=torch.bool + ) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43a61b4f7..ecd54844c 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,6 
+26,9 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" +OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask" +OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens" +OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask" ACTION = "action" ACTION_PREFIX = ACTION + "." diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py new file mode 100644 index 000000000..f80a6c72d --- /dev/null +++ b/tests/datasets/test_subtask_dataset.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for subtask functionality in LeRobotDataset. 
+ +These tests verify that: +- Subtask information is correctly loaded from datasets that have subtask data +- The __getitem__ method correctly adds subtask strings to returned items +- Subtask handling gracefully handles missing data +""" + +import pandas as pd +import pytest +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +class TestSubtaskDataset: + """Tests for subtask handling in LeRobotDataset.""" + + @pytest.fixture + def subtask_dataset(self): + """Load the test subtask dataset from the hub.""" + # Use lerobot/pusht-subtask dataset with episode 1 + return LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + + def test_subtask_dataset_loads(self, subtask_dataset): + """Test that the subtask dataset loads successfully.""" + assert subtask_dataset is not None + assert len(subtask_dataset) > 0 + + def test_subtask_metadata_loaded(self, subtask_dataset): + """Test that subtask metadata is loaded when present in dataset.""" + # The dataset should have subtasks metadata loaded + assert subtask_dataset.meta.subtasks is not None + assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame) + + def test_subtask_index_in_features(self, subtask_dataset): + """Test that subtask_index is a feature when dataset has subtasks.""" + assert "subtask_index" in subtask_dataset.features + + def test_getitem_returns_subtask_string(self, subtask_dataset): + """Test that __getitem__ correctly adds subtask string to returned item.""" + item = subtask_dataset[0] + + # Subtask should be present in the returned item + assert "subtask" in item + assert isinstance(item["subtask"], str) + assert len(item["subtask"]) > 0 # Should not be empty + + def test_getitem_has_subtask_index(self, subtask_dataset): + """Test that __getitem__ includes subtask_index.""" + item = subtask_dataset[0] + + assert "subtask_index" in item + assert isinstance(item["subtask_index"], torch.Tensor) + + def 
test_subtask_index_maps_to_valid_subtask(self, subtask_dataset): + """Test that subtask_index correctly maps to a subtask in metadata.""" + item = subtask_dataset[0] + + subtask_idx = item["subtask_index"].item() + subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name + + assert item["subtask"] == subtask_from_metadata + + def test_all_items_have_subtask(self, subtask_dataset): + """Test that all items in the dataset have subtask information.""" + for i in range(min(len(subtask_dataset), 5)): # Check first 5 items + item = subtask_dataset[i] + assert "subtask" in item + assert isinstance(item["subtask"], str) + + def test_task_and_subtask_coexist(self, subtask_dataset): + """Test that both task and subtask are present in returned items.""" + item = subtask_dataset[0] + + # Both task and subtask should be present + assert "task" in item + assert "subtask" in item + assert isinstance(item["task"], str) + assert isinstance(item["subtask"], str) + + +class TestSubtaskDatasetMissing: + """Tests for graceful handling when subtask data is missing.""" + + @pytest.fixture + def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory): + """Create a dataset without subtask information.""" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features) + + # Add some frames and save + for _ in range(5): + dataset.add_frame({"state": torch.randn(2), "task": "Test task"}) + dataset.save_episode() + dataset.finalize() + + # Reload the dataset + return LeRobotDataset(dataset.repo_id, root=dataset.root) + + def test_no_subtask_in_features(self, dataset_without_subtasks): + """Test that subtask_index is not in features when not provided.""" + assert "subtask_index" not in dataset_without_subtasks.features + + def test_getitem_without_subtask(self, dataset_without_subtasks): + """Test that __getitem__ works when subtask is not present.""" + item 
= dataset_without_subtasks[0] + + # Item should still be retrievable + assert item is not None + assert "state" in item + assert "task" in item + + # Subtask should NOT be present + assert "subtask" not in item + + def test_subtasks_metadata_is_none(self, dataset_without_subtasks): + """Test that subtasks metadata is None when not present.""" + assert dataset_without_subtasks.meta.subtasks is None + + +class TestSubtaskEdgeCases: + """Edge case tests for subtask handling.""" + + def test_subtask_with_multiple_episodes(self): + """Test subtask handling with multiple episodes if available.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + # Check first and last items have valid subtasks + first_item = dataset[0] + last_item = dataset[len(dataset) - 1] + + assert "subtask" in first_item + assert "subtask" in last_item + assert isinstance(first_item["subtask"], str) + assert isinstance(last_item["subtask"], str) + + def test_subtask_index_consistency(self): + """Test that same subtask_index returns same subtask string.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + if len(dataset) < 2: + pytest.skip("Dataset too small for this test") + + # Collect subtask_index to subtask mappings + subtask_map = {} + for i in range(min(len(dataset), 10)): + item = dataset[i] + idx = item["subtask_index"].item() + subtask = item["subtask"] + + if idx in subtask_map: + # Same index should always return same subtask + assert subtask_map[idx] == subtask, ( + f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'" + ) + else: + subtask_map[idx] = subtask diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 
d6f87f567..64cc8aac8 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -27,7 +27,14 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGE, + OBS_LANGUAGE, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, + OBS_STATE, +) from tests.utils import require_package @@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario(): # MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length) assert tokens.shape == (10,) # MockTokenizer behavior for single string in list assert attention_mask.shape == (10,) + + +# ============================================================================= +# Tests for get_subtask method +# ============================================================================= + + +@require_package("transformers") +def test_get_subtask_missing_key(): + """Test get_subtask returns None when subtask key is missing from complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No "subtask" key + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_value(): + """Test get_subtask returns None when subtask value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + 
transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_complementary_data(): + """Test get_subtask returns None when complementary_data is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data=None, # No complementary data + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_string(): + """Test get_subtask returns list with single string when subtask is a string.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the cube"}, + ) + + result = processor.get_subtask(transition) + assert result == ["pick up the cube"] + assert isinstance(result, list) + assert len(result) == 1 + + +@require_package("transformers") +def test_get_subtask_list_of_strings(): + """Test get_subtask returns the list when subtask is already a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + subtask_list = ["pick up", "move to target", "place down"] + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": subtask_list}, + ) + + result = 
processor.get_subtask(transition) + assert result == subtask_list + assert isinstance(result, list) + assert len(result) == 3 + + +@require_package("transformers") +def test_get_subtask_unsupported_type_integer(): + """Test get_subtask returns None when subtask is an unsupported type (integer).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_mixed_list(): + """Test get_subtask returns None when subtask is a list with mixed types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_dict(): + """Test get_subtask returns None when subtask is a dictionary.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": {"key": "value"}}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_empty_string(): + """Test get_subtask with empty string returns list with empty string.""" + mock_tokenizer = 
MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ""}, + ) + + result = processor.get_subtask(transition) + assert result == [""] + + +@require_package("transformers") +def test_get_subtask_empty_list(): + """Test get_subtask with empty list returns empty list.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": []}, + ) + + result = processor.get_subtask(transition) + assert result == [] + + +# ============================================================================= +# Tests for subtask tokenization in observation method +# ============================================================================= + + +@require_package("transformers") +def test_subtask_tokenization_when_present(): + """Test that subtask is tokenized and added to observation when present.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = 
observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert isinstance(subtask_tokens, torch.Tensor) + assert isinstance(subtask_attention_mask, torch.Tensor) + assert subtask_tokens.shape == (8,) + assert subtask_attention_mask.shape == (8,) + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_none(): + """Test that subtask tokens are NOT added to observation when subtask is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No subtask + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_subtask_value_is_none(): + """Test that subtask tokens are NOT added when subtask value is explicitly None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + 
+@require_package("transformers") +def test_subtask_tokenization_list_of_strings(): + """Test subtask tokenization with list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["pick up", "place down"]}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure for batch + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8 + assert subtask_attention_mask.shape == (2, 8) + + +@require_package("transformers") +def test_subtask_tokenization_device_cpu(): + """Test that subtask tokens are on CPU when other tensors are on CPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CPU tensors + observation = {OBS_STATE: torch.randn(10)} # CPU tensor + action = torch.randn(5) # CPU tensor + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens are on CPU + subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + assert subtask_tokens.device.type == "cpu" + assert subtask_attention_mask.device.type == "cpu" + + +@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_subtask_tokenization_device_cuda(): + """Test that subtask tokens are moved to CUDA when other tensors are on CUDA.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CUDA tensors + observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor + action = torch.randn(5).cuda() # CUDA tensor + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens are on CUDA + subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + assert subtask_tokens.device.type == "cuda" + assert subtask_attention_mask.device.type == "cuda" + + +@require_package("transformers") +def test_subtask_tokenization_preserves_other_observation_data(): + """Test that subtask tokenization preserves other observation data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + original_state = torch.tensor([1.0, 2.0, 3.0]) + transition = create_transition( + observation={"state": original_state.clone()}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + # Check that original observation data is preserved + assert torch.equal(observation["state"], original_state) + + # Check that both task and subtask tokens are present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert 
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + +@require_package("transformers") +def test_subtask_attention_mask_dtype(): + """Test that subtask attention mask has correct dtype (bool).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_deterministic(): + """Test that subtask tokenization is deterministic for the same input.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "consistent subtask"}, + ) + + result1 = processor(transition) + result2 = processor(transition) + + subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + # Results should be identical + assert torch.equal(subtask_tokens1, subtask_tokens2) + assert torch.equal(subtask_mask1, subtask_mask2) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer): + """Test subtask tokenization 
works correctly with DataProcessorPipeline.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + robot_processor = DataProcessorPipeline( + [tokenizer_processor], to_transition=identity_transition, to_output=identity_transition + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "subtask instruction"}, + ) + + result = robot_processor(transition) + + # Check that observation exists and both tokenizations were applied + assert TransitionKey.OBSERVATION in result + observation = result[TransitionKey.OBSERVATION] + + # Check task tokens + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check subtask tokens + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check shapes + assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,) + assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,) + + +@require_package("transformers") +def test_subtask_not_added_for_unsupported_types(): + """Test that subtask tokens are not added when subtask has unsupported type.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + # Test with integer subtask + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + # Subtask tokens should NOT be added for unsupported types + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in 
observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation From 9c24a09665ffe1338a91ee7a05a0e76d10e7e3d4 Mon Sep 17 00:00:00 2001 From: Hirokazu Ishida <38597814+HiroIshida@users.noreply.github.com> Date: Tue, 3 Feb 2026 04:05:58 +0900 Subject: [PATCH 15/43] docs: update document in response to Simplify configs PR (#1596) * docs: update document input/output_shapes -> input/output_features * fix inconsistent quote (suggested by copilot reviewer) * docs: shapes => PolicyFeature * docs: relfect normalization_mapping and remove outdated --- src/lerobot/configs/policies.py | 12 ++++----- src/lerobot/policies/act/configuration_act.py | 23 +++++----------- .../diffusion/configuration_diffusion.py | 23 +++++----------- .../policies/tdmpc/configuration_tdmpc.py | 26 +++++-------------- .../policies/vqbet/configuration_vqbet.py | 23 +++++----------- 5 files changed, 34 insertions(+), 73 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 7f326b70b..44b013c29 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). - input_shapes: A dictionary defining the shapes of the input data for the policy. - output_shapes: A dictionary defining the shapes of the output data for the policy. - input_normalization_modes: A dictionary with key representing the modality and the value specifies the - normalization mode to apply. - output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to - the original scale. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. 
The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) """ n_obs_steps: int = 1 diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py index 6f6c1c4be..bd89185fd 100644 --- a/src/lerobot/policies/act/configuration_act.py +++ b/src/lerobot/policies/act/configuration_act.py @@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig): Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and 'output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - Either: @@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig): This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. 
- output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. `None` means no pretrained weights. 
diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 54569434a..8322ca337 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig): Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and `output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - "observation.state" is required as an input key. @@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig): horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. 
"observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. diff --git a/src/lerobot/policies/tdmpc/configuration_tdmpc.py b/src/lerobot/policies/tdmpc/configuration_tdmpc.py index 3c1a29932..3ec493472 100644 --- a/src/lerobot/policies/tdmpc/configuration_tdmpc.py +++ b/src/lerobot/policies/tdmpc/configuration_tdmpc.py @@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig): camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`. + Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`. Args: n_action_repeats: The number of times to repeat the action returned by the planning. 
(hint: Google @@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig): is an alternative to using action repeats. If this is set to more than 1, then we require `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this approach of using multiple steps from the plan is not in the original implementation. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to - match the original implementation. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping - to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max" - normalization mode here. 
+ input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding. state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding. latent_dim: Observation's latent embedding dimension. diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py index 44ada9f17..32906e528 100644 --- a/src/lerobot/policies/vqbet/configuration_vqbet.py +++ b/src/lerobot/policies/vqbet/configuration_vqbet.py @@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig): Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and `output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - "observation.state" is required as an input key. @@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig): current step and additional steps going back). n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts. action_chunk_size: Action chunk size of each action prediction token. - input_shapes: A dictionary defining the shapes of the input data for the policy. 
- The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.image" refers to an input from - a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. 
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. From 14a15f90e762170209d283c3545523549841ca3d Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 2 Feb 2026 22:14:03 +0100 Subject: [PATCH 16/43] Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873) * fix(RL) add missing config arguments * respond to copilot review * fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic. --------- Co-authored-by: CarolinePascal --- src/lerobot/envs/configs.py | 1 + src/lerobot/processor/hil_processor.py | 22 ++++++++++++---------- src/lerobot/rl/gym_manipulator.py | 12 ++++++++++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index cd88b37bc..9c1c083a4 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -205,6 +205,7 @@ class ObservationConfig: add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False + add_ee_pose_to_observation: bool = False display_cameras: bool = False diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 6d44ed8cb..24b5628fa 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") -class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): +class GripperPenaltyProcessorStep(ProcessorStep): """ Applies a penalty for inefficient gripper usage. 
@@ -329,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): penalty: float = -0.01 max_gripper_pos: float = 30.0 - def complementary_data(self, complementary_data: dict) -> dict: + def __call__(self, transition: EnvTransition) -> EnvTransition: """ Calculates the gripper penalty and adds it to the complementary data. Args: - complementary_data: The incoming complementary data, which should contain - raw joint positions. + transition: The incoming environment transition. Returns: - A new complementary data dictionary with the `discrete_penalty` key added. + The modified transition with the penalty added to complementary data. """ - action = self.transition.get(TransitionKey.ACTION) + new_transition = transition.copy() + action = new_transition.get(TransitionKey.ACTION) + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) raw_joint_positions = complementary_data.get("raw_joint_positions") if raw_joint_positions is None: - return complementary_data + return new_transition current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None) if current_gripper_pos is None: - return complementary_data + return new_transition # Gripper action is a PolicyAction at this stage gripper_action = action[-1].item() @@ -364,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): gripper_penalty = self.penalty * int(gripper_penalty_bool) - # Create new complementary data with penalty info + # Update complementary data with penalty info new_complementary_data = dict(complementary_data) new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data - return new_complementary_data + return new_transition def get_config(self) -> dict[str, Any]: """ diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 3d58ae18f..1c1cb752f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ 
b/src/lerobot/rl/gym_manipulator.py @@ -412,7 +412,10 @@ def make_processors( if cfg.processor.observation.add_current_to_observation: env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot)) - if kinematics_solver is not None: + add_ee_pose = ( + cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation + ) + if kinematics_solver is not None and add_ee_pose: env_pipeline_steps.append( ForwardKinematicsJointsToEEObservation( kinematics=kinematics_solver, @@ -435,7 +438,12 @@ def make_processors( ) # Add gripper penalty processor if gripper config exists and enabled - if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper: + # Only add if max_gripper_pos is explicitly configured (required for normalization) + if ( + cfg.processor.gripper is not None + and cfg.processor.gripper.use_gripper + and cfg.processor.max_gripper_pos is not None + ): env_pipeline_steps.append( GripperPenaltyProcessorStep( penalty=cfg.processor.gripper.gripper_penalty, From a6370dd783c1048096b9596853beccc08a7b0bbd Mon Sep 17 00:00:00 2001 From: Iori Yanokura Date: Tue, 3 Feb 2026 22:17:04 +0900 Subject: [PATCH 17/43] fix(wandb): truncate init tags to 64-character limit (#995) --- src/lerobot/rl/wandb_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index 7b7f8a57b..ee30b75df 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.utils.constants import PRETRAINED_MODEL_DIR -def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: +def cfg_to_group( + cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64 +) -> list[str] | str: """Return a group name for logging. 
Optionally returns group name as list.""" + + def _maybe_truncate(tag: str) -> str: + """Truncate tag to max_tag_length characters if required. + + wandb rejects tags longer than 64 characters. + See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py + """ + if len(tag) <= max_tag_length: + return tag + return tag[:max_tag_length] + lst = [ f"policy:{cfg.policy.type}", f"seed:{cfg.seed}", @@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st lst.append(f"dataset:{cfg.dataset.repo_id}") if cfg.env is not None: lst.append(f"env:{cfg.env.type}") + if truncate_tags: + lst = [_maybe_truncate(tag) for tag in lst] return lst if return_list else "-".join(lst) @@ -83,7 +98,7 @@ class WandBLogger: entity=self.cfg.entity, name=self.job_name, notes=self.cfg.notes, - tags=cfg_to_group(cfg, return_list=True), + tags=cfg_to_group(cfg, return_list=True, truncate_tags=True), dir=self.log_dir, config=cfg.to_dict(), # TODO(rcadene): try set to True From 0f392484458cb5ebca0310c0c4c47390a31c80ed Mon Sep 17 00:00:00 2001 From: jwang078 Date: Tue, 3 Feb 2026 13:19:00 -0500 Subject: [PATCH 18/43] Small docstring fix in diffusion configuration (#2847) --- src/lerobot/policies/diffusion/configuration_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8322ca337..8ac0920dd 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -64,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig): use_group_norm: Whether to replace batch normalization with group normalization in the backbone. The group sizes are set to be about 16 (to be precise, feature_dim // 16). spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. 
- use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view. + use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view. down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. You may provide a variable number of dimensions, therefore also controlling the degree of downsampling. From 97e7e0f9ed8831daee04a6e5f67d777f689c87e4 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:39:58 +0000 Subject: [PATCH 19/43] feat(datasets): improve image transform support (#2885) * improve image transform support * add tests * Add stricter transform check and extra test * improve subclass check --- src/lerobot/datasets/transforms.py | 19 ++++++++++--------- tests/datasets/test_image_transforms.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index beacc48d9..5240619cb 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -216,16 +216,17 @@ class ImageTransformsConfig: def make_transform_from_config(cfg: ImageTransformConfig): - if cfg.type == "Identity": - return v2.Identity(**cfg.kwargs) - elif cfg.type == "ColorJitter": - return v2.ColorJitter(**cfg.kwargs) - elif cfg.type == "SharpnessJitter": + if cfg.type == "SharpnessJitter": return SharpnessJitter(**cfg.kwargs) - elif cfg.type == "RandomAffine": - return v2.RandomAffine(**cfg.kwargs) - else: - raise ValueError(f"Transform '{cfg.type}' is not valid.") + + transform_cls = getattr(v2, cfg.type, None) + if isinstance(transform_cls, type) and issubclass(transform_cls, Transform): + return transform_cls(**cfg.kwargs) + + raise ValueError( + f"Transform '{cfg.type}' is not valid. It must be a class in " + f"torchvision.transforms.v2 or 'SharpnessJitter'." 
+ ) class ImageTransforms(Transform): diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 8a66ceb24..ef7e8c395 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller(): SharpnessJitter((2.0, 0.1)) +def test_make_transform_from_config_with_v2_resize(img_tensor_factory): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)}) + tf = make_transform_from_config(tf_cfg) + assert isinstance(tf, v2.Resize) + output = tf(img_tensor) + assert output.shape[-2:] == (32, 32) + + +def test_make_transform_from_config_with_v2_identity(img_tensor_factory): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformConfig(type="Identity", kwargs={}) + tf = make_transform_from_config(tf_cfg) + assert isinstance(tf, v2.Identity) + output = tf(img_tensor) + assert output.shape == img_tensor.shape + + +def test_make_transform_from_config_invalid_type(): + tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={}) + with pytest.raises(ValueError, match="not valid"): + make_transform_from_config(tf_cfg) + + def test_save_all_transforms(img_tensor_factory, tmp_path): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig(enable=True) From e14bdf57d055e85ebc8a684efd2e4b9a4c7b6a37 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:46:12 +0000 Subject: [PATCH 20/43] Convert tensors to scalars (#2903) Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/modeling_smolvla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index c611e9ba2..60b968a42 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ 
b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy): actions_is_pad = batch.get("actions_id_pad") loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.clone() + loss_dict["losses_after_forward"] = losses.clone().mean().item() if actions_is_pad is not None: in_episode_bound = ~actions_is_pad losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.clone() + loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item() # Remove padding losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.clone() + loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims From 489cb7b6b9a39b569aaf02ff26df6725e9b36285 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 9 Feb 2026 16:58:32 +0100 Subject: [PATCH 21/43] fix(scripts): correct can import check (#2937) --- src/lerobot/scripts/lerobot_setup_can.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py index 55de74724..a31727ea4 100644 --- a/src/lerobot/scripts/lerobot_setup_can.py +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -45,7 +45,7 @@ from dataclasses import dataclass, field import draccus -from lerobot.utils.import_utils import is_package_available +from lerobot.utils.import_utils import _can_available MOTOR_NAMES = { 0x01: "joint_1", @@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig): @draccus.wrap() def setup_can(cfg: CANSetupConfig): - if not is_package_available("can"): + if not _can_available: print("Error: python-can not installed. 
Install with: pip install python-can") sys.exit(1) From cca0296cd6f0f281c5fd4628e836403628b59a05 Mon Sep 17 00:00:00 2001 From: Stepan Feduniak Date: Tue, 10 Feb 2026 13:55:11 +0100 Subject: [PATCH 22/43] fix(pipeline): use FeatureType for STATE features in Libero processor (#2888) * fix the types * pre-commit --------- Co-authored-by: Jade Choghari Co-authored-by: Steven Palma --- src/lerobot/processor/env_processor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index 8d42bfdb7..a77e066cf 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep): # copy over non-STATE features for ft, feats in features.items(): - if ft != PipelineFeatureType.STATE: + if ft != FeatureType.STATE: new_features[ft] = feats.copy() # rebuild STATE features @@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep): # add our new flattened state state_feats[OBS_STATE] = PolicyFeature( - key=OBS_STATE, + type=FeatureType.STATE, shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] - dtype="float32", - description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), ) - new_features[PipelineFeatureType.STATE] = state_feats + new_features[FeatureType.STATE] = state_feats return new_features From 5eba4ce6f453c2dfe4458b037bf3612df22f81ee Mon Sep 17 00:00:00 2001 From: Aoqun Jin Date: Tue, 10 Feb 2026 21:39:17 +0800 Subject: [PATCH 23/43] Change LIBERO init_state_id when 
reset. (#2899) * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * pre-commit run --------- Signed-off-by: Aoqun Jin Co-authored-by: Jade Choghari Co-authored-by: Steven Palma --- src/lerobot/envs/libero.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 96c5cf102..d20dae8ea 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -112,6 +112,7 @@ class LiberoEnv(gym.Env): visualization_height: int = 480, init_states: bool = True, episode_index: int = 0, + n_envs: int = 1, camera_name_mapping: dict[str, str] | None = None, num_steps_wait: int = 10, control_mode: str = "relative", @@ -145,7 +146,9 @@ class LiberoEnv(gym.Env): self.episode_length = episode_length # Load once and keep self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None - self._init_state_id = self.episode_index # tie each sub-env to a fixed init state + self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`. + + self.init_state_id = self.episode_index # tie each sub-env to a fixed init state self._env = self._make_envs_task(task_suite, self.task_id) default_steps = 500 @@ -295,7 +298,8 @@ class LiberoEnv(gym.Env): self._env.seed(seed) raw_obs = self._env.reset() if self.init_states and self._init_states is not None: - raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) + raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)]) + self.init_state_id += self._reset_stride # Change init_state_id when reset # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. 
@@ -373,6 +377,7 @@ def _make_env_fns( init_states=init_states, episode_length=episode_length, episode_index=episode_index, + n_envs=n_envs, control_mode=control_mode, **local_kwargs, ) From d2d01399d6773427347a37be401ec6ea35fa0e15 Mon Sep 17 00:00:00 2001 From: Jai Kumaar Ratadia Date: Tue, 10 Feb 2026 14:18:32 +0000 Subject: [PATCH 24/43] docs: clarify installation steps are sequential, not optional (#2925) * docs: clarify installation steps are sequential, not optional Add intro paragraph noting conda is one path (not the only one) and number the three sections as steps so readers understand miniforge and environment setup are prerequisites, not independent choices. * Update installation guide link for LeRobot Signed-off-by: Jai Kumaar Ratadia * Fix link formatting in installation guide again Signed-off-by: Jai Kumaar Ratadia --------- Signed-off-by: Jai Kumaar Ratadia Co-authored-by: Steven Palma --- docs/source/installation.mdx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 44d8c7034..8cc83843e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,13 +1,15 @@ # Installation -## Install [`miniforge`](https://conda-forge.org/download/) +This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). 
+ +## Step 1: Install [`miniforge`](https://conda-forge.org/download/) ```bash wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" bash Miniforge3-$(uname)-$(uname -m).sh ``` -## Environment Setup +## Step 2: Environment Setup Create a virtual environment with Python 3.10, using conda: @@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -## Install LeRobot 🤗 +## Step 3: Install LeRobot 🤗 ### From Source From 778db19a178b2fee539934419b1475b1f411bac9 Mon Sep 17 00:00:00 2001 From: whats2000 <60466660+whats2000@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:21:40 +0800 Subject: [PATCH 25/43] [Bug Fix] fix(ci): prevent runner group error on fork pushes (#2911) * fix(ci): prevent runner group error on fork pushes Add repository check to unbound_deps_tests workflow to ensure aws-general-8-plus runner group is only used on main repository, preventing 'Required runner group not found' errors on forks. * fix(ci): use gating job to prevent runner allocation on forks The previous approach failed because GitHub evaluates runs-on before if conditions. Now using a check-repo job that runs on ubuntu-latest first, and all jobs with special runners depend on it and check its output before being scheduled. * fix(ci): add gating job to full_tests to prevent runner allocation on forks Apply the same gating pattern used in unbound_deps_tests to full_tests.yml to prevent GitHub from trying to allocate custom runners when workflows run on forks. The check-repo job runs first on ubuntu-latest and all jobs with custom runners depend on it and check its output. 
* fix(ci): add repository check to unbound_deps_tests workflow Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker job to prevent runner group access errors on forks, matching the pattern used in nightly.yml * fix(ci): add repository check to full_tests workflow Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker and gpu-tests jobs to prevent runner group access errors on forks * refactor(ci): remove redundant check from gpu-tests job gpu-tests depends on build-and-push-docker via needs, so it will automatically skip when the parent job is skipped * refactor(ci): remove unnecessary fork check from full-tests job full-tests runs on ubuntu-latest which is available to all forks, no need to restrict it --------- Co-authored-by: Steven Palma --- .github/workflows/full_tests.yml | 8 +++++--- .github/workflows/unbound_deps_tests.yml | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 4dce3121a..fd5e422b3 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -101,9 +101,11 @@ jobs: runs-on: group: aws-general-8-plus if: | - (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || - github.event_name == 'push' || - github.event_name == 'workflow_dispatch' + github.repository == 'huggingface/lerobot' && ( + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' + ) outputs: image_tag: ${{ steps.set_tag.outputs.image_tag }} env: diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index a75ecc121..3f4ea3316 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ 
b/.github/workflows/unbound_deps_tests.yml @@ -91,6 +91,7 @@ jobs: name: Build and Push Docker runs-on: group: aws-general-8-plus + if: github.repository == 'huggingface/lerobot' outputs: image_tag: ${{ env.DOCKER_IMAGE_NAME }} env: From 35363c5798d129d7667c2efa43ddfa342639a35a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 17:35:39 +0100 Subject: [PATCH 26/43] chore(linter): ensure motors module passes MyPy type checks (#2939) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: ensure motors module passes MyPy type checks This commit fixes 62 mypy type errors in the motors module by: - Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead, GroupSyncWrite) to use class-level attribute declarations instead of __init__ body declarations - Adding missing `broadcastPing` method to PacketHandler Protocol - Fixing return type annotations (e.g., `_get_motor_model` returns str, not int) - Fixing parameter types to use `Sequence` for covariant list parameters - Fixing `Mapping` for covariant dict value types in `_normalize` - Updating method signatures to be consistent across parent and child classes (disable_torque, enable_torque, _get_half_turn_homings) - Adding explicit `int()` casts for MotorCalibration arguments - Adding explicit `return None` for functions returning Optional types - Adding type annotations for variables like `data_list: dict[int, int]` - Using `# type: ignore[method-assign]` for intentional monkeypatch - Fixing variable references (using `self.groups` instead of `groups`) Fixes #1723 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 * chore(style): pre-commit after main merge * chore(linter): solve comments * chore(linter): apply pre-commit fixes to damiao * chore(linter): more fixes to damiao --------- Co-authored-by: yurekami Co-authored-by: Claude Opus 4.5 --- pyproject.toml | 6 +- src/lerobot/motors/calibration_gui.py | 10 
+- src/lerobot/motors/damiao/damiao.py | 42 ++++- src/lerobot/motors/dynamixel/dynamixel.py | 16 +- src/lerobot/motors/feetech/feetech.py | 18 +-- src/lerobot/motors/motors_bus.py | 183 +++++++++++----------- 6 files changed, 157 insertions(+), 118 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 210d70b6b..c4b1c547e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,9 +360,9 @@ ignore_errors = false module = "lerobot.cameras.*" ignore_errors = false -# [[tool.mypy.overrides]] -# module = "lerobot.motors.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.motors.*" +ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.robots.*" diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 02bba454f..3410cb28a 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -221,7 +221,7 @@ class RangeFinderGUI: self.bus = bus self.groups = groups if groups is not None else {"all": list(bus.motors)} - self.group_names = list(groups) + self.group_names = list(self.groups) self.current_group = self.group_names[0] if not bus.is_connected: @@ -230,18 +230,20 @@ class RangeFinderGUI: self.calibration = bus.read_calibration() self.res_table = bus.model_resolution_table self.present_cache = { - m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors + m: bus.read("Present_Position", m, normalize=False) + for motors in self.groups.values() + for m in motors } pygame.init() self.font = pygame.font.Font(None, FONT_SIZE) - label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) + label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms) self.label_pad = label_pad width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 self.controls_bottom = 10 + SAVE_H self.base_y = self.controls_bottom + TOP_GAP - height = self.base_y + PADDING_Y * len(groups[self.current_group]) 
+ 40 + height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40 self.screen = pygame.display.set_mode((width, height)) pygame.display.set_caption("Motors range finder") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index c79f8d17e..95a9e70d1 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -211,6 +211,9 @@ class DamiaoMotorsBus(MotorsBusBase): logger.info("Starting handshake with motors...") # Drain any pending messages + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + while self.canbus.recv(timeout=0.01): pass @@ -283,6 +286,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -341,6 +348,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -356,6 +367,10 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: CAN message if received, None otherwise """ + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: start_time = time.time() messages_seen = [] @@ -394,10 +409,13 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: Dictionary mapping recv_id to CAN message """ - responses = {} + responses: dict[int, can.Message] = {} expected_set = set(expected_recv_ids) 
start_time = time.time() + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: # 100us poll timeout @@ -461,6 +479,9 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) motor_type = self._motor_types[motor_name] + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) @@ -488,6 +509,9 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Step 1: Send all MIT control commands for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): motor_id = self._get_motor_id(motor) @@ -656,6 +680,10 @@ class DamiaoMotorsBus(MotorsBusBase): def _batch_refresh(self, motors: list[str]) -> None: """Internal helper to refresh a list of motors and update cache.""" + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Send refresh commands for motor in motors: motor_id = self._get_motor_id(motor) @@ -678,10 +706,14 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. 
""" + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): @@ -690,6 +722,8 @@ class DamiaoMotorsBus(MotorsBusBase): elif data_name == "Goal_Position": # Step 1: Send all MIT control commands recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") for motor, value_degrees in values.items(): motor_id = self._get_motor_id(motor) motor_name = self._get_motor_name(motor) @@ -732,9 +766,9 @@ class DamiaoMotorsBus(MotorsBusBase): def record_ranges_of_motion( self, - motors: NameOrID | list[NameOrID] | None = None, + motors: str | list[str] | None = None, display_values: bool = True, - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + ) -> tuple[dict[str, Value], dict[str, Value]]: """ Interactively record the min/max values of each motor in degrees. diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index c6752ee96..bca455dc5 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus): for motor, m in self.motors.items(): calibration[motor] = MotorCalibration( id=m.id, - drive_mode=drive_modes[motor], - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + drive_mode=int(drive_modes[motor]), + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus): if cache: self.calibration = calibration_dict - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, 
TorqueMode.DISABLED.value, num_retry=num_retry) @@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) @@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus): On Dynamixel Motors: Present_Position = Actual_Position + Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None return {id_: data[0] for id_, data in data_list.items()} diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 7ce3388b6..58a65310d 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus): self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch - self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( + self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] self.port_handler, scs.PortHandler ) self.packet_handler = scs.PacketHandler(protocol_version) @@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus): calibration[motor] = MotorCalibration( id=m.id, drive_mode=0, - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + homing_offset=int(offsets[motor]), + 
range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus): On Feetech Motors: Present_Position = Actual_Position - Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus): return half_turn_homings - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) @@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Lock") self._write(addr, length, motor, 0, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Lock", motor, 1, num_retry=num_retry) @@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus): def _broadcast_ping(self) -> tuple[dict[int, int], int]: import scservo_sdk as scs - data_list = {} + data_list: dict[int, int] = {} status_length = 6 @@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus): if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} if ids_errors: diff --git a/src/lerobot/motors/motors_bus.py 
b/src/lerobot/motors/motors_bus.py index c04f718b6..bc3ffb7e2 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -23,6 +23,7 @@ from __future__ import annotations import abc import logging +from collections.abc import Sequence from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC): pass @abc.abstractmethod - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """Write values to multiple motors.""" pass @@ -179,15 +180,16 @@ class Motor: class PortHandler(Protocol): - def __init__(self, port_name): - self.is_open: bool - self.baudrate: int - self.packet_start_time: float - self.packet_timeout: float - self.tx_time_per_byte: float - self.is_using: bool - self.port_name: str - self.ser: serial.Serial + is_open: bool + baudrate: int + packet_start_time: float + packet_timeout: float + tx_time_per_byte: float + is_using: bool + port_name: str + ser: serial.Serial + + def __init__(self, port_name: str) -> None: ... def openPort(self): ... def closePort(self): ... @@ -240,19 +242,22 @@ class PacketHandler(Protocol): def regWriteTxRx(self, port, id, address, length, data): ... def syncReadTx(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... + def broadcastPing(self, port): ... 
class GroupSyncRead(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.last_result: bool - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + last_result: bool + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id): ... def removeParam(self, id): ... @@ -265,15 +270,17 @@ class GroupSyncRead(Protocol): class GroupSyncWrite(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id, data): ... def removeParam(self, id): ... 
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motor_model(self, motor: NameOrID) -> int: + def _get_motor_model(self, motor: NameOrID) -> str: if isinstance(motor, str): return self.motors[motor].model elif isinstance(motor, int): @@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]: if motors is None: return list(self.motors) elif isinstance(motors, str): return [motors] - elif isinstance(motors, list): - return motors.copy() + elif isinstance(motors, int): + return [self._id_to_name(motors)] + elif isinstance(motors, Sequence): + return [m if isinstance(m, str) else self._id_to_name(m) for m in motors] else: raise TypeError(motors) - def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]: if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): @@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase): pass @abc.abstractmethod - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: """Enable torque on selected motors. Args: - motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. + motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`. + Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. 
""" pass @contextmanager - def torque_disabled(self, motors: int | str | list[str] | None = None): + def torque_disabled(self, motors: str | list[str] | None = None): """Context-manager that guarantees torque is re-enabled. This helper is useful to temporarily disable torque when configuring motors. @@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase): """ pass - def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: + def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None: """Restore factory calibration for the selected motors. Homing offset is set to ``0`` and min/max position limits are set to the full usable range. The in-memory :pyattr:`calibration` is cleared. Args: - motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) + motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default) resets every motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - for motor in motors: + for motor in motor_names: model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 self.write("Homing_Offset", motor, 0, normalize=False) @@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase): self.calibration = {} - def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: + def set_half_turn_homings( + self, motors: NameOrID | Sequence[NameOrID] | None = None + ) -> dict[NameOrID, Value]: """Centre each motor range around its current position. The function computes and writes a homing offset such that the present position becomes exactly one @@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase): motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. 
Defaults to all motors (`None`). Returns: - dict[NameOrID, Value]: Mapping *motor → written homing offset*. + dict[str, Value]: Mapping *motor name → written homing offset*. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - self.reset_calibration(motors) - actual_positions = self.sync_read("Present_Position", motors, normalize=False) + self.reset_calibration(motor_names) + actual_positions = self.sync_read("Present_Position", motor_names, normalize=False) homing_offsets = self._get_half_turn_homings(actual_positions) for motor, offset in homing_offsets.items(): self.write("Homing_Offset", motor, offset) @@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase): pass def record_ranges_of_motion( - self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: """Interactively record the min/max encoder values of each motor. Move the joints by hand (with torque disabled) while the method streams live positions. Press @@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase): display_values (bool, optional): When `True` (default) a live table is printed to the console. Returns: - tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the + tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the extreme values observed for each motor. 
""" - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - start_positions = self.sync_read("Present_Position", motors, normalize=False) + start_positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = start_positions.copy() maxes = start_positions.copy() user_pressed_enter = False while not user_pressed_enter: - positions = self.sync_read("Present_Position", motors, normalize=False) + positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} if display_values: print("\n-------------------------------------------") print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") - for motor in motors: + for motor in motor_names: print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") if enter_pressed(): @@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase): if display_values and not user_pressed_enter: # Move cursor up to overwrite the previous output - move_cursor_up(len(motors) + 3) + move_cursor_up(len(motor_names) + 3) - same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] + same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]] if same_min_max: raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") @@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) else: - return + return None if self._is_error(error): if raise_on_error: raise RuntimeError(self.packet_handler.getRxPacketError(error)) else: - return + return None return model_number @@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase): err_msg = 
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - id_value = self._decode_sign(data_name, {id_: value}) + decoded = self._decode_sign(data_name, {id_: value}) if normalize and data_name in self.normalized_data: - id_value = self._normalize(id_value) + normalized = self._normalize(decoded) + return normalized[id_] - return id_value[id_] + return decoded[id_] def _read( self, @@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase): num_retry: int = 0, raise_on_error: bool = True, err_msg: str = "", - ) -> tuple[int, int]: + ) -> tuple[int, int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase): model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) + int_value = int(value) if normalize and data_name in self.normalized_data: - value = self._unnormalize({id_: value})[id_] + int_value = self._unnormalize({id_: value})[id_] - value = self._encode_sign(data_name, {id_: value})[id_] + int_value = self._encode_sign(data_name, {id_: int_value})[id_] - err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries." + self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( self, @@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase): def sync_read( self, data_name: str, - motors: str | list[str] | None = None, + motors: NameOrID | Sequence[NameOrID] | None = None, *, normalize: bool = True, num_retry: int = 0, @@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase): Args: data_name (str): Register name. 
- motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. + motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor. normalize (bool, optional): Normalisation flag. Defaults to `True`. num_retry (int, optional): Retry attempts. Defaults to `0`. @@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase): addr, length = get_address(self.model_ctrl_table, model, data_name) err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - ids_values, _ = self._sync_read( + raw_ids_values, _ = self._sync_read( addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg ) - ids_values = self._decode_sign(data_name, ids_values) + decoded = self._decode_sign(data_name, raw_ids_values) if normalize and data_name in self.normalized_data: - ids_values = self._normalize(ids_values) + normalized = self._normalize(decoded) + return {self._id_to_name(id_): value for id_, value in normalized.items()} - return {self._id_to_name(id_): value for id_, value in ids_values.items()} + return {self._id_to_name(id_): value for id_, value in decoded.items()} def _sync_read( self, @@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase): num_retry (int, optional): Retry attempts. Defaults to `0`. 
""" - ids_values = self._get_ids_values_dict(values) - models = [self._id_to_model(id_) for id_ in ids_values] + raw_ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in raw_ids_values] if self._has_different_ctrl_tables: assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) + int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()} if normalize and data_name in self.normalized_data: - ids_values = self._unnormalize(ids_values) + int_ids_values = self._unnormalize(raw_ids_values) - ids_values = self._encode_sign(data_name, ids_values) + int_ids_values = self._encode_sign(data_name, int_ids_values) - err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries." 
+ self._sync_write( + addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) def _sync_write( self, From 1ba3975020c8079630ff7dda8fe983ad473d7c12 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 17:49:30 +0100 Subject: [PATCH 27/43] chore: use is_connected decorators (#2948) * chore: use is_connected decorators * chore(robots): add is_connected to bi setups too --- src/lerobot/cameras/opencv/camera_opencv.py | 19 ++++++--------- .../cameras/reachy2_camera/reachy2_camera.py | 14 ++++------- .../cameras/realsense/camera_realsense.py | 23 +++++++------------ src/lerobot/cameras/zmq/camera_zmq.py | 16 +++++-------- src/lerobot/motors/damiao/damiao.py | 16 ++++--------- .../bi_openarm_follower.py | 5 ++++ .../robots/bi_so_follower/bi_so_follower.py | 5 ++++ .../openarm_follower/openarm_follower.py | 15 ++++-------- .../bi_openarm_leader/bi_openarm_leader.py | 4 ++++ .../bi_so_leader/bi_so_leader.py | 4 +++- .../openarm_leader/openarm_leader.py | 11 ++++----- 11 files changed, 57 insertions(+), 75 deletions(-) diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index d581e1425..465ba7a1b 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -32,7 +32,8 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" import cv2 # type: ignore # TODO: add type stubs for OpenCV -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..utils import get_cv2_backend, get_cv2_rotation @@ -132,6 +133,7 @@ class OpenCVCamera(Camera): """Checks if the camera is currently connected and opened.""" return isinstance(self.videocapture, 
cv2.VideoCapture) and self.videocapture.isOpened() + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """ Connects to the OpenCV camera specified in the configuration. @@ -148,8 +150,6 @@ class OpenCVCamera(Camera): ConnectionError: If the specified camera index/path is not found or fails to open. RuntimeError: If the camera opens but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") # Use 1 thread for OpenCV operations to avoid potential conflicts or # blocking in multi-threaded applications, especially during data collection. @@ -178,6 +178,7 @@ class OpenCVCamera(Camera): logger.info(f"{self} connected.") + @check_if_not_connected def _configure_capture_settings(self) -> None: """ Applies the specified FOURCC, FPS, width, and height settings to the connected camera. @@ -197,8 +198,6 @@ class OpenCVCamera(Camera): to the requested value. DeviceNotConnectedError: If the camera is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") # Set FOURCC first (if specified) as it can affect available FPS/resolution options if self.config.fourcc is not None: @@ -348,6 +347,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -374,9 +374,6 @@ class OpenCVCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -490,6 +487,7 @@ class OpenCVCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -512,8 +510,6 @@ class OpenCVCamera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -533,6 +529,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -548,8 +545,6 @@ class OpenCVCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 5cede466d..0c1dc43d8 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" import cv2 # type: ignore # TODO: add type stubs for OpenCV import numpy as np # type: ignore # TODO: add type stubs for numpy +from lerobot.utils.decorators import check_if_not_connected from lerobot.utils.import_utils import _reachy2_sdk_available if TYPE_CHECKING or _reachy2_sdk_available: @@ -123,6 +124,7 @@ class Reachy2Camera(Camera): """ raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.") + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -136,9 +138,6 @@ class Reachy2Camera(Camera): """ start_time = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -184,6 +183,7 @@ class Reachy2Camera(Camera): return frame + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Same as read() @@ -197,11 +197,10 @@ class Reachy2Camera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") return self.read() + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -219,8 +218,6 @@ class Reachy2Camera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.latest_frame is None or self.latest_timestamp is None: raise RuntimeError(f"{self} has not captured any frames yet.") @@ -233,6 +230,7 @@ class Reachy2Camera(Camera): return self.latest_frame + @check_if_not_connected def disconnect(self) -> None: """ Stops the background read thread (if running). @@ -240,8 +238,6 @@ class Reachy2Camera(Camera): Raises: DeviceNotConnectedError: If the camera is already disconnected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} not connected.") if self.cam_manager is not None: self.cam_manager.disconnect() diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index e47f25381..d599cdce0 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -30,7 +30,8 @@ try: except Exception as e: logging.info(f"Could not import realsense: {e}") -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -152,6 +153,7 @@ class RealSenseCamera(Camera): """Checks if the camera pipeline is started and streams are active.""" return self.rs_pipeline is not None and self.rs_profile is not None + @check_if_already_connected def 
connect(self, warmup: bool = True) -> None: """ Connects to the RealSense camera specified in the configuration. @@ -169,8 +171,6 @@ class RealSenseCamera(Camera): ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all. RuntimeError: If the pipeline starts but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") self.rs_pipeline = rs.pipeline() rs_config = rs.config() @@ -290,6 +290,7 @@ class RealSenseCamera(Camera): if self.use_depth: rs_config.enable_stream(rs.stream.depth) + @check_if_not_connected def _configure_capture_settings(self) -> None: """Sets fps, width, and height from device stream if not already configured. @@ -299,8 +300,6 @@ class RealSenseCamera(Camera): Raises: DeviceNotConnectedError: If device is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.") if self.rs_profile is None: raise RuntimeError(f"{self}: rs_profile must be initialized before use.") @@ -320,6 +319,7 @@ class RealSenseCamera(Camera): self.width, self.height = actual_width, actual_height self.capture_width, self.capture_height = actual_width, actual_height + @check_if_not_connected def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]: """ Reads a single frame (depth) synchronously from the camera. @@ -345,9 +345,6 @@ class RealSenseCamera(Camera): f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." 
) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -374,6 +371,7 @@ class RealSenseCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. @@ -403,9 +401,6 @@ class RealSenseCamera(Camera): f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -534,6 +529,7 @@ class RealSenseCamera(Camera): self.new_frame_event.clear() # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame data (color) asynchronously. @@ -556,8 +552,6 @@ class RealSenseCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread died unexpectedly or another error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -578,6 +572,7 @@ class RealSenseCamera(Camera): return frame # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). @@ -593,8 +588,6 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index f29e16a28..16523b50a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -34,7 +34,8 @@ import cv2 import numpy as np from numpy.typing import NDArray -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -104,6 +105,7 @@ class ZMQCamera(Camera): """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """Connect to ZMQ camera server. @@ -111,8 +113,6 @@ class ZMQCamera(Camera): warmup (bool): If True, waits for the camera to provide at least one valid frame before returning. Defaults to True. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") logger.info(f"Connecting to {self}...") @@ -211,6 +211,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -228,9 +229,6 @@ class ZMQCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -301,6 +299,7 @@ class ZMQCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -317,8 +316,6 @@ class ZMQCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread is not running. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -335,6 +332,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -350,8 +348,6 @@ class ZMQCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index 95a9e70d1..a454130a6 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -23,6 +23,7 @@ from copy import deepcopy from functools import cached_property from typing import TYPE_CHECKING, Any, TypedDict +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: @@ -36,7 +37,6 @@ else: import numpy as np -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import enter_pressed, move_cursor_up @@ -155,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase): """Check if the CAN bus is connected.""" return self._is_connected and self.canbus is not None + @check_if_already_connected def connect(self, handshake: bool = True) -> None: """ Open the CAN bus and initialize communication. @@ -162,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: handshake: If True, ping all motors to verify they're present """ - if self.is_connected: - raise DeviceAlreadyConnectedError( - f"{self.__class__.__name__}('{self.port}') is already connected." - ) try: # Auto-detect interface type based on port name @@ -249,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase): ) logger.info("Handshake successful. All motors ready.") + @check_if_not_connected def disconnect(self, disable_torque: bool = True) -> None: """ Close the CAN bus connection. 
@@ -256,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: disable_torque: If True, disable torque on all motors before disconnecting """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") if disable_torque: try: @@ -586,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase): except Exception as e: logger.warning(f"Failed to decode response from {motor}: {e}") + @check_if_not_connected def read(self, data_name: str, motor: str) -> Value: """Read a value from a single motor. Positions are always in degrees.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Refresh motor to get latest state msg = self._refresh_motor(motor) @@ -619,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase): raise ValueError(f"Unknown data_name: {data_name}") return mapping[data_name] + @check_if_not_connected def write( self, data_name: str, @@ -629,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase): Write a value to a single motor. Positions are always in degrees. Can write 'Goal_Position', 'Kp', or 'Kd'. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if data_name in ("Kp", "Kd"): self._gains[motor][data_name.lower()] = float(value) diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 466eb07e5..2e3885e67 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_openarm_follower import BiOpenArmFollowerConfig @@ -112,6 +113,7 @@ class BiOpenArmFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -133,6 +135,7 @@ class BiOpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." 
) + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -146,6 +149,7 @@ class BiOpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -170,6 +174,7 @@ class BiOpenArmFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 09f849772..28c58b898 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_so_follower import BiSOFollowerConfig @@ -96,6 +97,7 @@ class BiSOFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -116,6 +118,7 @@ class BiSOFollower(Robot): self.left_arm.setup_motors() self.right_arm.setup_motors() + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -129,6 +132,7 @@ class BiSOFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: # Remove "left_" prefix left_action = { @@ -148,6 +152,7 @@ class BiSOFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git 
a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c221afd10..d6794a226 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -119,6 +119,7 @@ class OpenArmFollower(Robot): """Check if robot is connected.""" return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the robot and optionally calibrate. @@ -126,8 +127,6 @@ class OpenArmFollower(Robot): We assume that at connection time, the arms are in a safe rest position, and torque can be safely disabled to run calibration if needed. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -219,6 +218,7 @@ class OpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_observation(self) -> RobotObservation: """ Get current observation from robot including position, velocity, and torque. 
@@ -228,9 +228,6 @@ class OpenArmFollower(Robot): """ start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - obs_dict: dict[str, Any] = {} states = self.bus.sync_read_all_states() @@ -253,6 +250,7 @@ class OpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -272,8 +270,6 @@ class OpenArmFollower(Robot): Returns: The action actually sent (potentially clipped) """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -333,10 +329,9 @@ class OpenArmFollower(Robot): return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): """Disconnect from robot.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus self.bus.disconnect(self.config.disable_torque_on_disconnect) diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index c4383293f..74b0c9b83 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader from ..teleoperator import Teleoperator @@ -88,6 +89,7 @@ class BiOpenArmLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) 
self.right_arm.connect(calibrate) @@ -109,6 +111,7 @@ class BiOpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: action_dict = {} @@ -126,6 +129,7 @@ class BiOpenArmLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 90bf2a92d..e84ac6f50 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,7 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig -from lerobot.utils.decorators import check_if_not_connected +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index edf4d7090..d9eaabe0f 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -21,7 +21,7 @@ from typing import Any from lerobot.motors 
import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_openarm_leader import OpenArmLeaderConfig @@ -84,6 +84,7 @@ class OpenArmLeader(Teleoperator): """Check if teleoperator is connected.""" return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the teleoperator. @@ -91,8 +92,6 @@ class OpenArmLeader(Teleoperator): For manual control, we disable torque after connecting so the arm can be moved by hand. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -183,6 +182,7 @@ class OpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: """ Get current action from the leader arm. @@ -193,8 +193,6 @@ class OpenArmLeader(Teleoperator): Reads all motor states (pos/vel/torque) in one CAN refresh cycle. 
""" start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") action_dict: dict[str, Any] = {} @@ -214,10 +212,9 @@ class OpenArmLeader(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.") + @check_if_not_connected def disconnect(self) -> None: """Disconnect from teleoperator.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus # For manual control, ensure torque is disabled before disconnecting From 3c84d271d53c9ca972cda8fce3b3f715ec813817 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 18:40:50 +0100 Subject: [PATCH 28/43] fix(motors): use decorator to fix precommit (#2951) --- src/lerobot/motors/damiao/damiao.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index a454130a6..ae619f159 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -700,14 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + @check_if_not_connected def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): From fc8a388a2538937992bd8b28bc7ac909ebd1b9a0 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 11 Feb 2026 13:57:25 +0100 Subject: [PATCH 29/43] feat(cameras): make backend configurable to the CLI (#2945) * feat(cameras): make backend configurable to the CLI * chore(cameras): address feedback * feat(Enum error messages): adding better instanciation error messages for Enum classes * chore(Enum error messages): propagating Enum error messages to all camera classes * chore(comments): removing superfluous comments * chore(format): applying ruff checks --------- Co-authored-by: CarolinePascal --- src/lerobot/cameras/__init__.py | 2 +- src/lerobot/cameras/configs.py | 23 +++++++++++++++++++ src/lerobot/cameras/opencv/camera_opencv.py | 4 ++-- .../cameras/opencv/configuration_opencv.py | 23 ++++++------------- .../configuration_reachy2_camera.py | 5 +--- .../realsense/configuration_realsense.py | 16 ++----------- src/lerobot/cameras/utils.py | 12 ---------- src/lerobot/cameras/zmq/configuration_zmq.py | 5 +--- 8 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py index 1488cd89e..cbf1f11bf 100644 --- a/src/lerobot/cameras/__init__.py +++ b/src/lerobot/cameras/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. 
from .camera import Camera -from .configs import CameraConfig, ColorMode, Cv2Rotation +from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation from .utils import make_cameras_from_configs diff --git a/src/lerobot/cameras/configs.py b/src/lerobot/cameras/configs.py index 056eec314..987b74775 100644 --- a/src/lerobot/cameras/configs.py +++ b/src/lerobot/cameras/configs.py @@ -25,6 +25,10 @@ class ColorMode(str, Enum): RGB = "rgb" BGR = "bgr" + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.") + class Cv2Rotation(int, Enum): NO_ROTATION = 0 @@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum): ROTATE_180 = 180 ROTATE_270 = -90 + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.") + + +# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html +class Cv2Backends(int, Enum): + ANY = 0 + V4L2 = 200 + DSHOW = 700 + PVAPI = 800 + ANDROID = 1000 + AVFOUNDATION = 1200 + MSMF = 1400 + + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.") + @dataclass(kw_only=True) class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 465ba7a1b..10b3f21da 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -36,7 +36,7 @@ from lerobot.utils.decorators import check_if_already_connected, check_if_not_co from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera -from ..utils import get_cv2_backend, get_cv2_rotation +from ..utils import get_cv2_rotation from .configuration_opencv import ColorMode, OpenCVCameraConfig # NOTE(Steven): 
The maximum opencv device index depends on your operating system. For instance, @@ -118,7 +118,7 @@ class OpenCVCamera(Camera): self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) - self.backend: int = get_cv2_backend() + self.backend: int = config.backend if self.height and self.width: self.capture_width, self.capture_height = self.width, self.height diff --git a/src/lerobot/cameras/opencv/configuration_opencv.py b/src/lerobot/cameras/opencv/configuration_opencv.py index 37a42861c..8ae57fe3c 100644 --- a/src/lerobot/cameras/opencv/configuration_opencv.py +++ b/src/lerobot/cameras/opencv/configuration_opencv.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from pathlib import Path -from ..configs import CameraConfig, ColorMode, Cv2Rotation +from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation -__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"] +__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"] @CameraConfig.register_subclass("opencv") @@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig): rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. warmup_s: Time reading frames before returning from connect (in seconds) fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect). + backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY. Note: - Only 3-channel color output (RGB/BGR) is currently supported. @@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig): rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION warmup_s: int = 1 fourcc: str | None = None + backend: Cv2Backends = Cv2Backends.ANY def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
- ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) + self.backend = Cv2Backends(self.backend) if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4): raise ValueError( diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py index ca6db4f03..b40bfe71b 100644 --- a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py @@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig): f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided." ) - if self.color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." - ) + self.color_mode = ColorMode(self.color_mode) diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py index a094128bc..71b083b00 100644 --- a/src/lerobot/cameras/realsense/configuration_realsense.py +++ b/src/lerobot/cameras/realsense/configuration_realsense.py @@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
- ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) values = (self.fps, self.width, self.height) if any(v is not None for v in values) and any(v is None for v in values): diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index c0e7b6284..7fb2c3bb1 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import cast from lerobot.utils.import_utils import make_device_from_device_class @@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None: return int(cv2.ROTATE_90_COUNTERCLOCKWISE) else: return None - - -def get_cv2_backend() -> int: - import cv2 - - if platform.system() == "Windows": - return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION - # elif platform.system() == "Darwin": # macOS - # return cv2.CAP_AVFOUNDATION - else: # Linux and others - return int(cv2.CAP_ANY) diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 4e7732cfc..13690e14c 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -32,10 +32,7 @@ class ZMQCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
- ) + self.color_mode = ColorMode(self.color_mode) if self.timeout_ms <= 0: raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.") From 3615160d891f00a1cb8258ed8f81d327049b640e Mon Sep 17 00:00:00 2001 From: taken-yjyoon Date: Fri, 13 Feb 2026 02:13:51 +0900 Subject: [PATCH 30/43] fix(typo): Fixing wrong argparse examples in the comments (using 'True' not 'true') (#1040) Co-authored-by: juni <> --- src/lerobot/processor/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 97ec716ff..8de376928 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): Args: save_directory: The directory where the pipeline will be saved. If None, saves to HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. - repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`. push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. card_kwargs: Additional arguments passed to the card template to customize the card. config_filename: The name of the JSON configuration file. 
If None, a name is From adebbcf090b47913d3f2e27bb27feccab174f2fc Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Thu, 12 Feb 2026 18:56:04 +0100 Subject: [PATCH 31/43] fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949) * fix(edit dataset operation): fixing dataset tools CLI operation type specification * test(edit dataset operation): adding tests for dataset tools operation type specification * chore(format): running pre-commit * chore(backward compatibility): adding a type property in OperationConfig for backward compatibility Signed-off-by: Caroline Pascal --- src/lerobot/scripts/lerobot_edit_dataset.py | 49 +++++++------- tests/scripts/test_edit_dataset_parsing.py | 71 +++++++++++++++++++++ 2 files changed, 96 insertions(+), 24 deletions(-) create mode 100644 tests/scripts/test_edit_dataset_parsing.py diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 2ca9c520d..7c222ac6c 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -109,11 +109,14 @@ Using JSON config file: --config_path path/to/edit_config.json """ +import abc import logging import shutil from dataclasses import dataclass from pathlib import Path +import draccus + from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, @@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging @dataclass -class DeleteEpisodesConfig: - type: str = "delete_episodes" +class OperationConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@OperationConfig.register_subclass("delete_episodes") +@dataclass +class DeleteEpisodesConfig(OperationConfig): episode_indices: list[int] | None = None +@OperationConfig.register_subclass("split") @dataclass -class SplitConfig: - type: str = "split" +class 
SplitConfig(OperationConfig): splits: dict[str, float | list[int]] | None = None +@OperationConfig.register_subclass("merge") @dataclass -class MergeConfig: - type: str = "merge" +class MergeConfig(OperationConfig): repo_ids: list[str] | None = None +@OperationConfig.register_subclass("remove_feature") @dataclass -class RemoveFeatureConfig: - type: str = "remove_feature" +class RemoveFeatureConfig(OperationConfig): feature_names: list[str] | None = None +@OperationConfig.register_subclass("modify_tasks") @dataclass -class ModifyTasksConfig: - type: str = "modify_tasks" +class ModifyTasksConfig(OperationConfig): new_task: str | None = None episode_tasks: dict[str, str] | None = None +@OperationConfig.register_subclass("convert_image_to_video") @dataclass -class ConvertImageToVideoConfig: - type: str = "convert_image_to_video" +class ConvertImageToVideoConfig(OperationConfig): output_dir: str | None = None vcodec: str = "libsvtav1" pix_fmt: str = "yuv420p" @@ -177,14 +187,7 @@ class ConvertImageToVideoConfig: @dataclass class EditDatasetConfig: repo_id: str - operation: ( - DeleteEpisodesConfig - | SplitConfig - | MergeConfig - | RemoveFeatureConfig - | ModifyTasksConfig - | ConvertImageToVideoConfig - ) + operation: OperationConfig root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: - raise ValueError( - f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" - ) + available = ", ".join(OperationConfig.get_known_choices()) + raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") def main() -> None: diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py new file mode 100644 index 000000000..bf7386b52 
--- /dev/null +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import draccus +import pytest + +from lerobot.scripts.lerobot_edit_dataset import ( + ConvertImageToVideoConfig, + DeleteEpisodesConfig, + EditDatasetConfig, + MergeConfig, + ModifyTasksConfig, + OperationConfig, + RemoveFeatureConfig, + SplitConfig, +) + + +def parse_cfg(cli_args: list[str]) -> EditDatasetConfig: + """Helper to parse CLI args into an EditDatasetConfig via draccus.""" + return draccus.parse(EditDatasetConfig, args=cli_args) + + +class TestOperationTypeParsing: + """Test that --operation.type correctly selects the right config subclass.""" + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + ("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_operation_type_resolves_correct_class(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + assert isinstance(cfg.operation, expected_cls), ( + f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}" + ) + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + 
("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_get_choice_name_roundtrips(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) + assert resolved_name == type_name From 6600b60e7f5cc7476ddc34beaaf0e0692f82e4b6 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:49:01 +0100 Subject: [PATCH 32/43] always use degrees (#2968) --- src/lerobot/robots/so_follower/config_so_follower.py | 2 +- src/lerobot/teleoperators/so_leader/config_so_leader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/robots/so_follower/config_so_follower.py b/src/lerobot/robots/so_follower/config_so_follower.py index e9ce27123..1ee589bda 100644 --- a/src/lerobot/robots/so_follower/config_so_follower.py +++ b/src/lerobot/robots/so_follower/config_so_follower.py @@ -40,7 +40,7 @@ class SOFollowerConfig: cameras: dict[str, CameraConfig] = field(default_factory=dict) # Set to `True` for backward compatibility with previous policies/dataset - use_degrees: bool = False + use_degrees: bool = True @RobotConfig.register_subclass("so101_follower") diff --git a/src/lerobot/teleoperators/so_leader/config_so_leader.py b/src/lerobot/teleoperators/so_leader/config_so_leader.py index dd55196d7..2b4f782a7 100644 --- a/src/lerobot/teleoperators/so_leader/config_so_leader.py +++ b/src/lerobot/teleoperators/so_leader/config_so_leader.py @@ -28,7 +28,7 @@ class SOLeaderConfig: port: str # Whether to use degrees for angles - use_degrees: bool = False + use_degrees: bool = True @TeleoperatorConfig.register_subclass("so101_leader") From 51d3822d75491507561b5f11db0e62d56b342d93 Mon Sep 17 00:00:00 2001 From: masato-ka Date: Wed, 18 Feb 2026 04:09:42 
+0900 Subject: [PATCH 33/43] feat(datasets): Add info operation to lerobot-edit-dataset command (#2917) * Add New featrue to lerobot_edit_datset.py that show dataset information. * Fix to draccus error when happen give only --operation.type=info * Updating test and documents regarding lerobot-edit-dataset info function. * Updating documents regarding lerobot-edit-dataset extract function. option name in document is mistake. * feat(datasets): Update to align formatting with pre-commit.(#2917) Update to align formatting by pre-commit. --------- Co-authored-by: Caroline Pascal --- docs/source/using_dataset_tools.mdx | 25 ++++++++ src/lerobot/scripts/lerobot_edit_dataset.py | 65 +++++++++++++++++++++ tests/scripts/test_edit_dataset_parsing.py | 3 + 3 files changed, 93 insertions(+) diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx index 9e662604e..f7fc9be20 100644 --- a/docs/source/using_dataset_tools.mdx +++ b/docs/source/using_dataset_tools.mdx @@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets: 4. **Add Features** - Add new features to a dataset 5. **Remove Features** - Remove features from a dataset 6. **Convert to Video** - Convert image-based datasets to video format for efficient storage +7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc. The core implementation is in `lerobot.datasets.dataset_tools`. An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. @@ -156,6 +157,30 @@ lerobot-edit-dataset \ **Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved. +### Show the information of datasets + +Show the information of datasets such as number of episode, number of frame, File size and so on. 
+No change will be made to the dataset + +```bash + +# Show dataset information without feature details +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + +# Show dataset information with feature details +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features true + +``` + +**Parameters:** + +- `parameters`: The flag to control show or no show dataset information with feature details.(default=false) + ### Push to Hub Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 7c222ac6c..06e256fa2 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -104,6 +104,18 @@ Convert image dataset to video format and push to hub: --operation.type convert_image_to_video \ --push_to_hub true +Show dataset information: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features true + +Show dataset information without feature details: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features false + Using JSON config file: python -m lerobot.scripts.lerobot_edit_dataset \ --config_path path/to/edit_config.json @@ -112,6 +124,7 @@ Using JSON config file: import abc import logging import shutil +import sys from dataclasses import dataclass from pathlib import Path @@ -184,6 +197,13 @@ class ConvertImageToVideoConfig(OperationConfig): max_frames_per_batch: int | None = None +@OperationConfig.register_subclass("info") +@dataclass +class InfoConfig(OperationConfig): + type: str = "info" + show_features: bool = False + + @dataclass class EditDatasetConfig: repo_id: str @@ -436,6 +456,49 @@ def 
handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: logging.info("Dataset saved locally (not pushed to hub)") +def _get_dataset_size(repo_path): + import os + + total = 0 + with os.scandir(repo_path) as it: + for entry in it: + if entry.is_file(): + total += entry.stat().st_size + elif entry.is_dir(): + total += _get_dataset_size(entry.path) + return total + + +def handle_info(cfg: EditDatasetConfig): + if not isinstance(cfg.operation, InfoConfig): + raise ValueError("Operation config must be InfoConfig") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + sys.stdout.write(f"======Info {dataset.meta.repo_id}\n") + sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n") + sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n") + sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n") + sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n") + sys.stdout.write( + f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n" + ) + sys.stdout.write( + f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n" + ) + sys.stdout.write(f"FPS: {dataset.meta.fps}\n") + + total_file_size = _get_dataset_size(dataset.root) + sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n") + if cfg.operation.show_features: + import json + + feature_dump_str = json.dumps( + dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ") + ) + sys.stdout.write("Features:\n") + sys.stdout.write(f"{feature_dump_str}\n") + + @parser.wrap() def edit_dataset(cfg: EditDatasetConfig) -> None: operation_type = cfg.operation.type @@ -452,6 +515,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) + elif operation_type == "info": + handle_info(cfg) else: available = ", 
".join(OperationConfig.get_known_choices()) raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py index bf7386b52..8800b92ee 100644 --- a/tests/scripts/test_edit_dataset_parsing.py +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -21,6 +21,7 @@ from lerobot.scripts.lerobot_edit_dataset import ( ConvertImageToVideoConfig, DeleteEpisodesConfig, EditDatasetConfig, + InfoConfig, MergeConfig, ModifyTasksConfig, OperationConfig, @@ -46,6 +47,7 @@ class TestOperationTypeParsing: ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), + ("info", InfoConfig), ], ) def test_operation_type_resolves_correct_class(self, type_name, expected_cls): @@ -63,6 +65,7 @@ class TestOperationTypeParsing: ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), + ("info", InfoConfig), ], ) def test_get_choice_name_roundtrips(self, type_name, expected_cls): From 1c388c0002c609ca783bf42729a1e41532a1fba0 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 17 Feb 2026 23:37:46 +0100 Subject: [PATCH 34/43] (Chore) Bump upper bound for torch version (#2897) * Bump upper torch version bound * Apply suggestion from @Copilot Signed-off-by: Vladislav Sovrasov * Update ref state dicts for schedulers * Support older than 2.8 torch versions * Fix precommit --------- Signed-off-by: Vladislav Sovrasov --- pyproject.toml | 6 +++--- tests/optim/test_schedulers.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c4b1c547e..e5431ada3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,9 @@ dependencies = [ "pyserial>=3.5,<4.0", "wandb>=0.24.0,<0.25.0", - "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency - "torchcodec>=0.2.1,<0.6.0; 
sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency - "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency + "torch>=2.2.1,<2.11.0", # TODO: Bump dependency + "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency + "torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency "draccus==0.10.0", # TODO: Remove == "gymnasium>=1.1.1,<2.0.0", diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 1e566a6ba..224613416 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import torch +from packaging.version import Version from torch.optim.lr_scheduler import LambdaLR from lerobot.optim.schedulers import ( @@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict From af036ce57e8ce2750f7fa57f4262c87a013bcdff Mon Sep 17 00:00:00 2001 From: Sota Nakamura <49087984+sotanakamura@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:05:51 +0900 Subject: [PATCH 35/43] fix(scripts): serve grpc for a web viewer (#2881) * serve grpc for a web viewer * add help * remove ip detection * fix comment * pass grpc_port * fix(CLI): fixing CLI display-compressed-images argument 1/2 Co-authored-by: HUANG TZU-CHUN Signed-off-by: Caroline Pascal * fix(CLI): fixing CLI display-compressed-images argument 2/2 Co-authored-by: HUANG TZU-CHUN Signed-off-by: Caroline Pascal --------- Signed-off-by: Caroline Pascal Co-authored-by: Caroline Pascal Co-authored-by: HUANG TZU-CHUN Co-authored-by: Steven Palma --- src/lerobot/scripts/lerobot_dataset_viz.py | 37 +++++++++++++++------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2cd48eab8..29d64554f 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ 
-47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd ``` - Visualize data stored on a distant machine through streaming: -(You need to forward the websocket port to the distant machine, with -`ssh -L 9087:localhost:9087 username@remote-host`) ``` distant$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 \ --mode distant \ - --ws-port 9087 + --grpc-port 9876 -local$ rerun ws://localhost:9087 +local$ rerun rerun+http://IP:GRPC_PORT/proxy ``` """ @@ -75,6 +73,7 @@ import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD +from lerobot.utils.utils import init_logging def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: @@ -93,10 +92,11 @@ def visualize_dataset( num_workers: int = 0, mode: str = "local", web_port: int = 9090, - ws_port: int = 9087, + grpc_port: int = 9876, save: bool = False, output_dir: Path | None = None, display_compressed_images: bool = False, + **kwargs, ) -> Path | None: if save: assert output_dir is not None, ( @@ -126,7 +126,9 @@ def visualize_dataset( gc.collect() if mode == "distant": - rr.serve_web_viewer(open_browser=False, web_port=web_port) + server_uri = rr.serve_grpc(grpc_port=grpc_port) + logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy") + rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri) logging.info("Logging to Rerun") @@ -226,7 +228,7 @@ def main(): "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " "'distant' creates a server on the distant machine where the data is stored. " - "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine." 
), ) parser.add_argument( @@ -238,8 +240,13 @@ def main(): parser.add_argument( "--ws-port", type=int, - default=9087, - help="Web socket port for rerun.io when `--mode distant` is set.", + help="deprecated, please use --grpc-port instead.", + ) + parser.add_argument( + "--grpc-port", + type=int, + default=9876, + help="gRPC port for rerun.io when `--mode distant` is set.", ) parser.add_argument( "--save", @@ -265,9 +272,7 @@ def main(): parser.add_argument( "--display-compressed-images", - type=bool, - required=True, - default=False, + action="store_true", help="If set, display compressed images in Rerun instead of uncompressed ones.", ) @@ -277,6 +282,14 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") + if kwargs["ws_port"] is not None: + logging.warning( + "--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead." + ) + logging.warning("Setting grpc_port to ws_port value.") + kwargs["grpc_port"] = kwargs.pop("ws_port") + + init_logging() logging.info("Loading dataset") dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s) From fcbf550952b3794e425e20d01ec76475be54be4e Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Wed, 18 Feb 2026 18:27:40 +0800 Subject: [PATCH 36/43] fix(docs): update environment variable name to HF_LEROBOT_HOME in docstring (#2973) Co-authored-by: Steven Palma --- src/lerobot/datasets/lerobot_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 36bffa190..360ed8d30 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset): repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset will be stored under root/repo_id. 
root (Path | None, optional): Local directory to use for downloading/writing files. You can also - set the LEROBOT_HOME environment variable to point to a different location. Defaults to + set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to '~/.cache/huggingface/lerobot'. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. From b22e0315b05447efbc9a0eb1d612192aad0337c2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 17:32:25 +0100 Subject: [PATCH 37/43] fix(utils): more conservative sleep_margin default value in precise_sleep (#2985) --- src/lerobot/utils/robot_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/utils/robot_utils.py b/src/lerobot/utils/robot_utils.py index 28c8e7c49..656dc2649 100644 --- a/src/lerobot/utils/robot_utils.py +++ b/src/lerobot/utils/robot_utils.py @@ -16,14 +16,14 @@ import platform import time -def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003): +def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005): """ Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage. Parameters: - seconds: duration to wait - spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms - - sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms + - sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms Note: The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case. 
From 89bd58a9a26ec5820df13866b6ebc1670ed8cd83 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 18:22:35 +0100 Subject: [PATCH 38/43] chore(scripts): warn if we don't respect the target FPS (#2986) --- src/lerobot/scripts/lerobot_record.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 0b39e6fff..216ab22a6 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -398,7 +398,14 @@ def record_loop( ) dt_s = time.perf_counter() - start_loop_t - precise_sleep(max(1 / fps - dt_s, 0.0)) + + sleep_time_s: float = 1 / fps - dt_s + if sleep_time_s < 0: + logging.warning( + f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + precise_sleep(max(sleep_time_s, 0.0)) timestamp = time.perf_counter() - start_episode_t From aaf37070587581b3ffa8a28b6c134e846afe3a2e Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Wed, 18 Feb 2026 19:16:53 +0100 Subject: [PATCH 39/43] fix(filtering): fixing episodes filtering in load_nested_dataset to always use .from_parquet() (#2982) --- src/lerobot/datasets/utils.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 321ecedd5..da186bf30 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -122,19 +122,9 @@ def load_nested_dataset( raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") with SuppressProgressBars(): - # When no filtering needed, Dataset uses memory-mapped loading for efficiency - # PyArrow loads the entire dataset into memory - if episodes is None: - return Dataset.from_parquet([str(path) for 
path in paths], features=features) - - arrow_dataset = pa_ds.dataset(paths, format="parquet") - filter_expr = pa_ds.field("episode_index").isin(episodes) - table = arrow_dataset.to_table(filter=filter_expr) - - if features is not None: - table = table.cast(features.arrow_schema) - - return Dataset(table) + # We use .from_parquet() memory-mapped loading for efficiency + filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None + return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) def get_parquet_num_frames(parquet_path: str | Path) -> int: From bc38261321f377621a05595914798023bc05d301 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 20:05:15 +0100 Subject: [PATCH 40/43] feat(robots): use read_latest() camera (#2987) * feat(robots): use read_latest() camera * fix(test): add read_latest reachy cam mock --- src/lerobot/cameras/camera.py | 2 +- src/lerobot/cameras/opencv/camera_opencv.py | 2 +- src/lerobot/cameras/reachy2_camera/reachy2_camera.py | 2 +- src/lerobot/cameras/realsense/camera_realsense.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- src/lerobot/robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/omx_follower/omx_follower.py | 2 +- src/lerobot/robots/openarm_follower/openarm_follower.py | 2 +- src/lerobot/robots/reachy2/robot_reachy2.py | 2 +- src/lerobot/robots/so_follower/so_follower.py | 2 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 2 +- tests/robots/test_reachy2.py | 1 + 14 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index 2894e0215..2a53d2544 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -150,7 +150,7 @@ class Camera(abc.ABC): """ pass - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, 
max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 10b3f21da..f3289ddc7 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -530,7 +530,7 @@ class OpenCVCamera(Camera): return frame @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 0c1dc43d8..9bef957bc 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -201,7 +201,7 @@ class Reachy2Camera(Camera): return self.read() @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index d599cdce0..d80ec8093 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -573,7 +573,7 @@ class RealSenseCamera(Camera): # NOTE(Steven): Missing implementation for depth for now @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). 
This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 5fd9c4d1d..e8269ae46 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -140,7 +140,7 @@ class HopeJrArm(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 1e5c72b72..a05c4bbcb 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -171,7 +171,7 @@ class HopeJrHand(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index fee0adba9..53a32beed 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -193,7 +193,7 @@ class KochFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 54848f49d..9d11a000f 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -360,7 +360,7 @@ class 
LeKiwi(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index a171affbd..e0b612c60 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -176,7 +176,7 @@ class OmxFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index d6794a226..c865f1ec1 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -241,7 +241,7 @@ class OpenArmFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 6f4eef56c..fb466f85b 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -180,7 +180,7 @@ class Reachy2Robot(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() return obs_dict diff --git a/src/lerobot/robots/so_follower/so_follower.py 
b/src/lerobot/robots/so_follower/so_follower.py index b4d11fe3f..bc72a2b6a 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -187,7 +187,7 @@ class SOFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 01b4f330e..df0de8f19 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -324,7 +324,7 @@ class UnitreeG1(Robot): # Cameras - read images from ZMQ cameras for cam_name, cam in self._cameras.items(): - obs[cam_name] = cam.async_read() + obs[cam_name] = cam.read_latest() return obs diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py index d3c44bf5a..d3f32b1c2 100644 --- a/tests/robots/test_reachy2.py +++ b/tests/robots/test_reachy2.py @@ -142,6 +142,7 @@ def _make_reachy2_camera_mock(*args, **kwargs): cam.connect = MagicMock() cam.disconnect = MagicMock() cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) + cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) return cam From 5f15232271a81ee6be16cec1960e300f55f25466 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 22:46:12 +0100 Subject: [PATCH 41/43] chore: remove usernames + use entrypoints in docs, comments & sample commands (#2988) --- benchmarks/video/README.md | 84 +++++++++---------- docs/source/earthrover_mini_plus.mdx | 2 +- docs/source/hope_jr.mdx | 10 +-- docs/source/pi0.mdx | 2 +- docs/source/pi05.mdx | 2 +- docs/source/sarm.mdx | 8 +- docs/source/unitree_g1.mdx | 4 +- docs/source/walloss.mdx | 2 +- docs/source/xvla.mdx | 2 
+- examples/backward_compatibility/replay.py | 2 +- examples/rtc/eval_dataset.py | 20 ++--- examples/rtc/eval_with_real_robot.py | 6 +- .../v30/convert_dataset_v21_to_v30.py | 2 +- .../policies/sarm/compute_rabc_weights.py | 10 +-- .../policies/smolvla/modeling_smolvla.py | 4 +- src/lerobot/scripts/lerobot_edit_dataset.py | 32 +++---- src/lerobot/scripts/lerobot_replay.py | 2 +- 17 files changed, 97 insertions(+), 97 deletions(-) diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md index 490a4b495..1feee69c4 100644 --- a/benchmarks/video/README.md +++ b/benchmarks/video/README.md @@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat For these reasons, we run this benchmark on four representative datasets: - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera. -- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. -- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. -- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera. +- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. +- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. +- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera. Note: The datasets used for this benchmark need to be image datasets, not video datasets. 
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ + lerobot/aloha_mobile_shrimp_image \ --vcodec libx264 libx265 \ --pix-fmt yuv444p yuv420p \ --g 2 20 None \ @@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ - aliberts/paris_street \ - aliberts/kitchen \ + lerobot/aloha_mobile_shrimp_image \ + lerobot/paris_street \ + lerobot/kitchen \ --vcodec libx264 libx265 \ --pix-fmt yuv444p yuv420p \ --g 1 2 3 4 5 6 10 15 20 40 None \ @@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ - aliberts/paris_street \ - aliberts/kitchen \ + lerobot/aloha_mobile_shrimp_image \ + lerobot/paris_street \ + lerobot/kitchen \ --vcodec libsvtav1 \ --pix-fmt yuv420p \ --g 1 2 3 4 5 6 10 15 20 40 None \ @@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav` -| video_images_size_ratio | vcodec | pix_fmt | | | | -| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- | -| | libx264 | | libx265 | | libsvtav1 | -| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | -| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% | -| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% | -| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | +| video_images_size_ratio | vcodec | pix_fmt | | | | +| --------------------------------- | ---------- | ------- | --------- | --------- | --------- | 
+| | libx264 | | libx265 | | libsvtav1 | +| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | +| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% | +| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% | +| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | -| video_images_load_time_ratio | vcodec | pix_fmt | | | | -| ---------------------------------- | ------- | ------- | -------- | ------- | --------- | -| | libx264 | | libx265 | | libsvtav1 | -| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | -| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** | -| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** | -| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | +| video_images_load_time_ratio | vcodec | pix_fmt | | | | +| --------------------------------- | ------- | ------- | -------- | ------- | --------- | +| | libx264 | | libx265 | | libsvtav1 | +| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | +| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** | +| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** | +| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | -| | | vcodec | pix_fmt | | | | -| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | -| | | libx264 | | libx265 | | libsvtav1 | -| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | -| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 | -| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% | -| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 
2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** | -| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** | -| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** | -| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** | -| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** | -| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** | -| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** | -| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** | -| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** | +| | | vcodec | pix_fmt | | | | +| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | +| | | libx264 | | libx265 | | libsvtav1 | +| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | +| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 | +| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% | +| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** | +| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** | +| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** | +| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** | +| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** | +| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** | +| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** | +| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** | +| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** | diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index d8083336a..dd9c2ad2b 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -185,7 +185,7 @@ echo 
$HF_USER Use the standard recording command: ```bash -python src/lerobot/scripts/lerobot_record.py \ +lerobot-record \ --robot.type=earthrover_mini_plus \ --teleop.type=keyboard_rover \ --dataset.repo_id=your_username/dataset_name \ diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 856febb95..026cd084a 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -224,7 +224,7 @@ lerobot-record \ --teleop.port=/dev/tty.usbmodem1201 \ --teleop.id=right \ --teleop.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.repo_id=/hand_record_test_with_video_data \ --dataset.single_task="Hand recording test with video data" \ --dataset.num_episodes=1 \ --dataset.episode_time_s=5 \ @@ -241,7 +241,7 @@ lerobot-replay \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ --robot.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_camera \ + --dataset.repo_id=/hand_record_test_with_camera \ --dataset.episode=0 ``` @@ -249,13 +249,13 @@ lerobot-replay \ ```bash lerobot-train \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.repo_id=/hand_record_test_with_video_data \ --policy.type=act \ --output_dir=outputs/train/hopejr_hand \ --job_name=hopejr \ --policy.device=mps \ --wandb.enable=true \ - --policy.repo_id=nepyope/hand_test_policy + --policy.repo_id=/hand_test_policy ``` ### Evaluate @@ -270,7 +270,7 @@ lerobot-record \ --robot.side=right \ --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ --display_data=false \ - --dataset.repo_id=nepyope/eval_hopejr \ + --dataset.repo_id=/eval_hopejr \ --dataset.single_task="Evaluate hopejr hand policy" \ --dataset.num_episodes=10 \ --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index 93e0b4c88..879bbd16d 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -60,7 +60,7 @@ 
policy.type=pi0 For training π₀, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi0 \ --output_dir=./outputs/pi0_training \ diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index dbf118aa3..8abaca989 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -56,7 +56,7 @@ policy.type=pi05 Here's a complete training command for finetuning the base π₀.₅ model on your own dataset: ```bash -python src/lerobot/scripts/lerobot_train.py\ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi05 \ --output_dir=./outputs/pi05_training \ diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index 65e49792b..cd488fe1f 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl Train with **no annotations** - uses linear progress from 0 to 1: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=single_stage \ @@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \ Train with **dense annotations only** (sparse auto-generated): ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=dense_only \ @@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \ Train with **both sparse and dense annotations**: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=dual \ @@ -468,7 +468,7 @@ This script: Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). 
Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ --use_rabc=true \ diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index ea6bf54ad..4c5d28924 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -216,7 +216,7 @@ lerobot-teleoperate \ ### Record Dataset in Simulation ```bash -python -m lerobot.scripts.lerobot_record \ +lerobot-record \ --robot.type=unitree_g1 \ --robot.is_simulation=true \ --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ @@ -266,7 +266,7 @@ lerobot-teleoperate \ ### Record Dataset on Real Robot ```bash -python -m lerobot.scripts.lerobot_record \ +lerobot-record \ --robot.type=unitree_g1 \ --robot.is_simulation=false \ --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ diff --git a/docs/source/walloss.mdx b/docs/source/walloss.mdx index c0756c087..e9785cc93 100644 --- a/docs/source/walloss.mdx +++ b/docs/source/walloss.mdx @@ -45,7 +45,7 @@ policy.type=wall_x For training WallX, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=wall_x \ --output_dir=./outputs/wallx_training \ diff --git a/docs/source/xvla.mdx b/docs/source/xvla.mdx index dd7d1ef57..97e04d4ec 100644 --- a/docs/source/xvla.mdx +++ b/docs/source/xvla.mdx @@ -154,7 +154,7 @@ lerobot-train \ ```bash lerobot-train \ - --dataset.repo_id=pepijn223/bimanual-so100-handover-cube \ + --dataset.repo_id=/bimanual-so100-handover-cube \ --output_dir=./outputs/xvla_bimanual \ --job_name=xvla_so101_training \ 
--policy.path="lerobot/xvla-base" \ diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 8de5ba197..f7c47bec5 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -22,7 +22,7 @@ lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/record-test \ --dataset.episode=2 ``` """ diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 4652df107..613fd67d7 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -27,8 +27,8 @@ measuring consistency and ground truth alignment. Usage: # Basic usage with smolvla policy uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ --rtc.max_guidance_weight=10.0 \ @@ -58,16 +58,16 @@ Usage: --device=cuda uv run python examples/rtc/eval_dataset.py \ - --policy.path=lipsop/reuben_pi0 \ - --dataset.repo_id=ReubenLim/so101_cube_in_cup \ + --policy.path=/reuben_pi0 \ + --dataset.repo_id=/so101_cube_in_cup \ --rtc.execution_horizon=8 \ --device=cuda # With torch.compile for faster inference (PyTorch 2.0+) # Note: CUDA graphs disabled by default due to in-place ops in denoising loop uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ --use_torch_compile=true \ @@ -75,8 +75,8 @@ Usage: # With torch.compile on CUDA (CUDA graphs disabled by default) uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - 
--dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=cuda \ --use_torch_compile=true \ @@ -84,8 +84,8 @@ Usage: # Enable CUDA graphs (advanced - may cause tensor aliasing errors) uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --use_torch_compile=true \ --torch_compile_backend=inductor \ --torch_compile_mode=max-autotune \ diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 1470899d9..4c803eb7e 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py Usage: # Run RTC with Real robot with RTC uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ + --policy.path=/smolvla_check_rtc_last3 \ --policy.device=mps \ --rtc.enabled=true \ --rtc.execution_horizon=20 \ @@ -41,7 +41,7 @@ Usage: # Run RTC with Real robot without RTC uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ + --policy.path=/smolvla_check_rtc_last3 \ --policy.device=mps \ --rtc.enabled=false \ --robot.type=so100_follower \ @@ -53,7 +53,7 @@ Usage: # Run RTC with Real robot with pi0.5 policy uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/pi05_check_rtc \ + --policy.path=/pi05_check_rtc \ --policy.device=mps \ --rtc.enabled=true \ --rtc.execution_horizon=20 \ diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 74be6bfa4..7be37a1b1 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -529,7 +529,7 @@ if __name__ == 
"__main__": type=str, required=True, help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " - "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + "(e.g. `lerobot/pusht`, `/aloha_sim_insertion_human`).", ) parser.add_argument( "--branch", diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py index 5b6ea6e9b..485c1096b 100644 --- a/src/lerobot/policies/sarm/compute_rabc_weights.py +++ b/src/lerobot/policies/sarm/compute_rabc_weights.py @@ -27,18 +27,18 @@ Usage: # Full RA-BC computation with visualizations python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 + --reward-model-path /sarm_single_uni4 # Faster computation with stride (compute every 5 frames, interpolate the rest) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --stride 5 # Visualize predictions only (no RA-BC computation) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --visualize-only \\ --num-visualizations 5 @@ -714,12 +714,12 @@ Examples: # Full RA-BC computation with visualizations python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 + --reward-model-path /sarm_single_uni4 # Visualize predictions only (no RA-BC computation) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ 
--visualize-only \\ --num-visualizations 10 """, diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 60b968a42..10544a949 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`): ```bash lerobot-train \ --policy.path=lerobot/smolvla_base \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--dataset.repo_id=/svla_so100_task1_v3 \ --batch_size=64 \ --steps=200000 ``` @@ -40,7 +40,7 @@ and an action expert. ```bash lerobot-train \ --policy.type=smolvla \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--dataset.repo_id=/svla_so100_task1_v3 \ --batch_size=64 \ --steps=200000 ``` diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 06e256fa2..afdc95efd 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -24,100 +24,100 @@ When new_repo_id is specified, creates a new dataset. 
Usage Examples: Delete episodes 0, 2, and 5 from a dataset: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" Delete episodes and save to a new dataset: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --new_repo_id lerobot/pusht_filtered \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" Split dataset by fractions: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": 0.8, "val": 0.2}' Split dataset by episode indices: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' Split into more than two splits: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' Merge multiple datasets: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_merged \ --operation.type merge \ --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" Remove camera feature: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" Modify tasks - set a single task for all episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.new_task "Pick up the cube and place it" Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + 
lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.new_task "Default task" \ --operation.episode_tasks '{"5": "Special task for episode 5"}' Convert image dataset to video format and save locally: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video Convert image dataset to video format and save with new repo_id: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ --operation.type convert_image_to_video Convert image dataset to video format and push to hub: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ --operation.type convert_image_to_video \ --push_to_hub true Show dataset information: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type info \ --operation.show_features true Show dataset information without feature details: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type info \ --operation.show_features false Using JSON config file: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --config_path path/to/edit_config.json """ diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c9a559d07..8e2a394b9 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ 
b/src/lerobot/scripts/lerobot_replay.py @@ -22,7 +22,7 @@ lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/record-test \ --dataset.episode=0 ``` From 2dd366436ed30ed9729b4f18076a54fec7ec589b Mon Sep 17 00:00:00 2001 From: Khalil Date: Thu, 19 Feb 2026 14:35:02 +0100 Subject: [PATCH 42/43] Fix gym-hil integration with the new LeRobot pipeline. (#2482) * Add GymHILAdapterProcessorStep for gym-hil environment integration * Fix action features in control loop for None teleop device with gym-hil * Finalize dataset before pushing to hub for visualization on the hub * Fix neutral action for gripper * fix pre-commit --- src/lerobot/processor/__init__.py | 2 ++ src/lerobot/processor/gym_action_processor.py | 8 +++++ src/lerobot/processor/hil_processor.py | 31 +++++++++++++++++++ src/lerobot/rl/gym_manipulator.py | 15 +++++++-- 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 164f7da03..0b63e1606 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -44,6 +44,7 @@ from .hil_processor import ( AddTeleopActionAsComplimentaryDataStep, AddTeleopEventsAsInfoStep, GripperPenaltyProcessorStep, + GymHILAdapterProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, RewardClassifierProcessorStep, @@ -87,6 +88,7 @@ __all__ = [ "DoneProcessorStep", "EnvAction", "EnvTransition", + "GymHILAdapterProcessorStep", "GripperPenaltyProcessorStep", "hotswap_stats", "IdentityProcessorStep", diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 8fa8cfd86..4f225af92 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from 
.converters import to_tensor from .core import EnvAction, EnvTransition, PolicyAction +from .hil_processor import TELEOP_ACTION_KEY from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry @@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): torch_action = to_tensor(action, dtype=None) # Preserve original dtype new_transition[TransitionKey.ACTION] = torch_action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if TELEOP_ACTION_KEY in complementary_data: + teleop_action = complementary_data[TELEOP_ACTION_KEY] + if isinstance(teleop_action, EnvAction): + complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition def transform_features( diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 24b5628fa..34eaeed51 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -312,6 +312,37 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): return features +@ProcessorStepRegistry.register("gym_hil_adapter_processor") +class GymHILAdapterProcessorStep(ProcessorStep): + """ + Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors. + + This step normalizes the `transition` object by: + 1. Copying `teleop_action` from `info` to `complementary_data`. + 2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key). 
+ """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition.get(TransitionKey.INFO, {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + if TELEOP_ACTION_KEY in info: + complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY] + + if "is_intervention" in info: + info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"] + + transition[TransitionKey.INFO] = info + transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") class GripperPenaltyProcessorStep(ProcessorStep): diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 1c1cb752f..f5fcb7437 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -36,6 +36,7 @@ from lerobot.processor import ( DeviceProcessorStep, EnvTransition, GripperPenaltyProcessorStep, + GymHILAdapterProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, MapDeltaActionToRobotActionStep, @@ -379,6 +380,7 @@ def make_processors( ] env_pipeline_steps = [ + GymHILAdapterProcessorStep(), Numpy2TorchActionProcessorStep(), VanillaObservationProcessorStep(), AddBatchDimensionProcessorStep(), @@ -608,7 +610,14 @@ def control_loop( dataset = None if cfg.mode == "record": - action_features = teleop_device.action_features + if teleop_device: + action_features = teleop_device.action_features + else: + action_features = { + "dtype": "float32", + "shape": (4,), + "names": ["delta_x", "delta_y", "delta_z", "gripper"], + } features = { ACTION: action_features, REWARD: {"dtype": "float32", "shape": (1,), "names": None}, @@ -656,7 +665,7 @@ def control_loop( # Create a neutral action (no movement) neutral_action = 
torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) if use_gripper: - neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay + neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay # Use the new step function transition = step_env_and_process_transition( @@ -725,6 +734,8 @@ def control_loop( precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) if dataset is not None and cfg.dataset.push_to_hub: + logging.info("Finalizing dataset before pushing to hub") + dataset.finalize() logging.info("Pushing dataset to hub") dataset.push_to_hub() From 5865170d36442b907bb35f946e837eee18aafdf1 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 20 Feb 2026 17:01:46 +0100 Subject: [PATCH 43/43] chore(deps): bump ceil datasets (#2946) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e5431ada3..0ca1f0432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici dependencies = [ # Hugging Face dependencies - "datasets>=4.0.0,<4.2.0", + "datasets>=4.0.0,<5.0.0", "diffusers>=0.27.2,<0.36.0", "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", "accelerate>=1.10.0,<2.0.0",