diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 4dce3121a..fd5e422b3 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -101,9 +101,11 @@ jobs: runs-on: group: aws-general-8-plus if: | - (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || - github.event_name == 'push' || - github.event_name == 'workflow_dispatch' + github.repository == 'huggingface/lerobot' && ( + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' + ) outputs: image_tag: ${{ steps.set_tag.outputs.image_tag }} env: diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index a75ecc121..3f4ea3316 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -91,6 +91,7 @@ jobs: name: Build and Push Docker runs-on: group: aws-general-8-plus + if: github.repository == 'huggingface/lerobot' outputs: image_tag: ${{ env.DOCKER_IMAGE_NAME }} env: diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 44d8c7034..8cc83843e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,13 +1,15 @@ # Installation -## Install [`miniforge`](https://conda-forge.org/download/) +This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). + +## Step 1: Install [`miniforge`](https://conda-forge.org/download/) ```bash wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" bash Miniforge3-$(uname)-$(uname -m).sh ``` -## Environment Setup +## Step 2: Environment Setup Create a virtual environment with Python 3.10, using conda: @@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -## Install LeRobot 🤗 +## Step 3: Install LeRobot 🤗 ### From Source diff --git a/pyproject.toml b/pyproject.toml index 4ad189e23..143462341 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -361,9 +361,9 @@ ignore_errors = false module = "lerobot.cameras.*" ignore_errors = false -# [[tool.mypy.overrides]] -# module = "lerobot.motors.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.motors.*" +ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.robots.*" diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py index 1488cd89e..cbf1f11bf 100644 --- a/src/lerobot/cameras/__init__.py +++ b/src/lerobot/cameras/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. from .camera import Camera -from .configs import CameraConfig, ColorMode, Cv2Rotation +from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation from .utils import make_cameras_from_configs diff --git a/src/lerobot/cameras/configs.py b/src/lerobot/cameras/configs.py index 056eec314..987b74775 100644 --- a/src/lerobot/cameras/configs.py +++ b/src/lerobot/cameras/configs.py @@ -25,6 +25,10 @@ class ColorMode(str, Enum): RGB = "rgb" BGR = "bgr" + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.") + class Cv2Rotation(int, Enum): NO_ROTATION = 0 @@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum): ROTATE_180 = 180 ROTATE_270 = -90 + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.") + + +# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html +class Cv2Backends(int, Enum): + ANY = 0 + V4L2 = 200 + DSHOW = 700 + PVAPI = 800 + ANDROID = 1000 + AVFOUNDATION = 1200 + MSMF = 1400 + + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.") + @dataclass(kw_only=True) class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index d581e1425..10b3f21da 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -32,10 +32,11 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" import cv2 # type: ignore # TODO: add type stubs for OpenCV -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera -from ..utils import get_cv2_backend, get_cv2_rotation +from ..utils import get_cv2_rotation from .configuration_opencv import ColorMode, OpenCVCameraConfig # NOTE(Steven): The maximum opencv device index depends on your operating system. For instance, @@ -117,7 +118,7 @@ class OpenCVCamera(Camera): self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) - self.backend: int = get_cv2_backend() + self.backend: int = config.backend if self.height and self.width: self.capture_width, self.capture_height = self.width, self.height @@ -132,6 +133,7 @@ class OpenCVCamera(Camera): """Checks if the camera is currently connected and opened.""" return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened() + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """ Connects to the OpenCV camera specified in the configuration. @@ -148,8 +150,6 @@ class OpenCVCamera(Camera): ConnectionError: If the specified camera index/path is not found or fails to open. RuntimeError: If the camera opens but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") # Use 1 thread for OpenCV operations to avoid potential conflicts or # blocking in multi-threaded applications, especially during data collection. @@ -178,6 +178,7 @@ class OpenCVCamera(Camera): logger.info(f"{self} connected.") + @check_if_not_connected def _configure_capture_settings(self) -> None: """ Applies the specified FOURCC, FPS, width, and height settings to the connected camera. @@ -197,8 +198,6 @@ class OpenCVCamera(Camera): to the requested value. DeviceNotConnectedError: If the camera is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") # Set FOURCC first (if specified) as it can affect available FPS/resolution options if self.config.fourcc is not None: @@ -348,6 +347,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -374,9 +374,6 @@ class OpenCVCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -490,6 +487,7 @@ class OpenCVCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -512,8 +510,6 @@ class OpenCVCamera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -533,6 +529,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -548,8 +545,6 @@ class OpenCVCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/opencv/configuration_opencv.py b/src/lerobot/cameras/opencv/configuration_opencv.py index 37a42861c..8ae57fe3c 100644 --- a/src/lerobot/cameras/opencv/configuration_opencv.py +++ b/src/lerobot/cameras/opencv/configuration_opencv.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from pathlib import Path -from ..configs import CameraConfig, ColorMode, Cv2Rotation +from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation -__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"] +__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"] @CameraConfig.register_subclass("opencv") @@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig): rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. warmup_s: Time reading frames before returning from connect (in seconds) fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect). + backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY. Note: - Only 3-channel color output (RGB/BGR) is currently supported. @@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig): rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION warmup_s: int = 1 fourcc: str | None = None + backend: Cv2Backends = Cv2Backends.ANY def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." - ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) + self.backend = Cv2Backends(self.backend) if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4): raise ValueError( diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py index ca6db4f03..b40bfe71b 100644 --- a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py @@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig): f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided." ) - if self.color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." - ) + self.color_mode = ColorMode(self.color_mode) diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 5cede466d..0c1dc43d8 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" import cv2 # type: ignore # TODO: add type stubs for OpenCV import numpy as np # type: ignore # TODO: add type stubs for numpy +from lerobot.utils.decorators import check_if_not_connected from lerobot.utils.import_utils import _reachy2_sdk_available if TYPE_CHECKING or _reachy2_sdk_available: @@ -123,6 +124,7 @@ class Reachy2Camera(Camera): """ raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.") + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -136,9 +138,6 @@ class Reachy2Camera(Camera): """ start_time = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -184,6 +183,7 @@ class Reachy2Camera(Camera): return frame + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Same as read() @@ -197,11 +197,10 @@ class Reachy2Camera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") return self.read() + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -219,8 +218,6 @@ class Reachy2Camera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.latest_frame is None or self.latest_timestamp is None: raise RuntimeError(f"{self} has not captured any frames yet.") @@ -233,6 +230,7 @@ class Reachy2Camera(Camera): return self.latest_frame + @check_if_not_connected def disconnect(self) -> None: """ Stops the background read thread (if running). @@ -240,8 +238,6 @@ class Reachy2Camera(Camera): Raises: DeviceNotConnectedError: If the camera is already disconnected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} not connected.") if self.cam_manager is not None: self.cam_manager.disconnect() diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index e47f25381..d599cdce0 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -30,7 +30,8 @@ try: except Exception as e: logging.info(f"Could not import realsense: {e}") -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -152,6 +153,7 @@ class RealSenseCamera(Camera): """Checks if the camera pipeline is started and streams are active.""" return self.rs_pipeline is not None and self.rs_profile is not None + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """ Connects to the RealSense camera specified in the configuration. @@ -169,8 +171,6 @@ class RealSenseCamera(Camera): ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all. RuntimeError: If the pipeline starts but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") self.rs_pipeline = rs.pipeline() rs_config = rs.config() @@ -290,6 +290,7 @@ class RealSenseCamera(Camera): if self.use_depth: rs_config.enable_stream(rs.stream.depth) + @check_if_not_connected def _configure_capture_settings(self) -> None: """Sets fps, width, and height from device stream if not already configured. @@ -299,8 +300,6 @@ class RealSenseCamera(Camera): Raises: DeviceNotConnectedError: If device is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.") if self.rs_profile is None: raise RuntimeError(f"{self}: rs_profile must be initialized before use.") @@ -320,6 +319,7 @@ class RealSenseCamera(Camera): self.width, self.height = actual_width, actual_height self.capture_width, self.capture_height = actual_width, actual_height + @check_if_not_connected def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]: """ Reads a single frame (depth) synchronously from the camera. @@ -345,9 +345,6 @@ class RealSenseCamera(Camera): f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -374,6 +371,7 @@ class RealSenseCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. @@ -403,9 +401,6 @@ class RealSenseCamera(Camera): f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -534,6 +529,7 @@ class RealSenseCamera(Camera): self.new_frame_event.clear() # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame data (color) asynchronously. @@ -556,8 +552,6 @@ class RealSenseCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread died unexpectedly or another error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -578,6 +572,7 @@ class RealSenseCamera(Camera): return frame # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). @@ -593,8 +588,6 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py index a094128bc..71b083b00 100644 --- a/src/lerobot/cameras/realsense/configuration_realsense.py +++ b/src/lerobot/cameras/realsense/configuration_realsense.py @@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." - ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) values = (self.fps, self.width, self.height) if any(v is not None for v in values) and any(v is None for v in values): diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index c0e7b6284..7fb2c3bb1 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import cast from lerobot.utils.import_utils import make_device_from_device_class @@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None: return int(cv2.ROTATE_90_COUNTERCLOCKWISE) else: return None - - -def get_cv2_backend() -> int: - import cv2 - - if platform.system() == "Windows": - return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION - # elif platform.system() == "Darwin": # macOS - # return cv2.CAP_AVFOUNDATION - else: # Linux and others - return int(cv2.CAP_ANY) diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index f29e16a28..16523b50a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -34,7 +34,8 @@ import cv2 import numpy as np from numpy.typing import NDArray -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -104,6 +105,7 @@ class ZMQCamera(Camera): """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """Connect to ZMQ camera server. @@ -111,8 +113,6 @@ class ZMQCamera(Camera): warmup (bool): If True, waits for the camera to provide at least one valid frame before returning. Defaults to True. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") logger.info(f"Connecting to {self}...") @@ -211,6 +211,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -228,9 +229,6 @@ class ZMQCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -301,6 +299,7 @@ class ZMQCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -317,8 +316,6 @@ class ZMQCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread is not running. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -335,6 +332,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -350,8 +348,6 @@ class ZMQCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 4e7732cfc..13690e14c 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -32,10 +32,7 @@ class ZMQCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." - ) + self.color_mode = ColorMode(self.color_mode) if self.timeout_ms <= 0: raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.") diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 96c5cf102..d20dae8ea 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -112,6 +112,7 @@ class LiberoEnv(gym.Env): visualization_height: int = 480, init_states: bool = True, episode_index: int = 0, + n_envs: int = 1, camera_name_mapping: dict[str, str] | None = None, num_steps_wait: int = 10, control_mode: str = "relative", @@ -145,7 +146,9 @@ class LiberoEnv(gym.Env): self.episode_length = episode_length # Load once and keep self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None - self._init_state_id = self.episode_index # tie each sub-env to a fixed init state + self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`. + + self.init_state_id = self.episode_index # tie each sub-env to a fixed init state self._env = self._make_envs_task(task_suite, self.task_id) default_steps = 500 @@ -295,7 +298,8 @@ class LiberoEnv(gym.Env): self._env.seed(seed) raw_obs = self._env.reset() if self.init_states and self._init_states is not None: - raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) + raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)]) + self.init_state_id += self._reset_stride # Change init_state_id when reset # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. @@ -373,6 +377,7 @@ def _make_env_fns( init_states=init_states, episode_length=episode_length, episode_index=episode_index, + n_envs=n_envs, control_mode=control_mode, **local_kwargs, ) diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 02bba454f..3410cb28a 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -221,7 +221,7 @@ class RangeFinderGUI: self.bus = bus self.groups = groups if groups is not None else {"all": list(bus.motors)} - self.group_names = list(groups) + self.group_names = list(self.groups) self.current_group = self.group_names[0] if not bus.is_connected: @@ -230,18 +230,20 @@ class RangeFinderGUI: self.calibration = bus.read_calibration() self.res_table = bus.model_resolution_table self.present_cache = { - m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors + m: bus.read("Present_Position", m, normalize=False) + for motors in self.groups.values() + for m in motors } pygame.init() self.font = pygame.font.Font(None, FONT_SIZE) - label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) + label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms) self.label_pad = label_pad width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 self.controls_bottom = 10 + SAVE_H self.base_y = self.controls_bottom + TOP_GAP - height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40 + height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40 self.screen = pygame.display.set_mode((width, height)) pygame.display.set_caption("Motors range finder") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index c79f8d17e..ae619f159 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -23,6 +23,7 @@ from copy import deepcopy from functools import cached_property from typing import TYPE_CHECKING, Any, TypedDict +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: @@ -36,7 +37,6 @@ else: import numpy as np -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import enter_pressed, move_cursor_up @@ -155,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase): """Check if the CAN bus is connected.""" return self._is_connected and self.canbus is not None + @check_if_already_connected def connect(self, handshake: bool = True) -> None: """ Open the CAN bus and initialize communication. @@ -162,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: handshake: If True, ping all motors to verify they're present """ - if self.is_connected: - raise DeviceAlreadyConnectedError( - f"{self.__class__.__name__}('{self.port}') is already connected." - ) try: # Auto-detect interface type based on port name @@ -211,6 +208,9 @@ class DamiaoMotorsBus(MotorsBusBase): logger.info("Starting handshake with motors...") # Drain any pending messages + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + while self.canbus.recv(timeout=0.01): pass @@ -246,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase): ) logger.info("Handshake successful. All motors ready.") + @check_if_not_connected def disconnect(self, disable_torque: bool = True) -> None: """ Close the CAN bus connection. @@ -253,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: disable_torque: If True, disable torque on all motors before disconnecting """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") if disable_torque: try: @@ -283,6 +282,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -341,6 +344,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -356,6 +363,10 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: CAN message if received, None otherwise """ + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: start_time = time.time() messages_seen = [] @@ -394,10 +405,13 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: Dictionary mapping recv_id to CAN message """ - responses = {} + responses: dict[int, can.Message] = {} expected_set = set(expected_recv_ids) start_time = time.time() + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: # 100us poll timeout @@ -461,6 +475,9 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) motor_type = self._motor_types[motor_name] + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) @@ -488,6 +505,9 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Step 1: Send all MIT control commands for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): motor_id = self._get_motor_id(motor) @@ -562,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase): except Exception as e: logger.warning(f"Failed to decode response from {motor}: {e}") + @check_if_not_connected def read(self, data_name: str, motor: str) -> Value: """Read a value from a single motor. Positions are always in degrees.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Refresh motor to get latest state msg = self._refresh_motor(motor) @@ -595,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase): raise ValueError(f"Unknown data_name: {data_name}") return mapping[data_name] + @check_if_not_connected def write( self, data_name: str, @@ -605,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase): Write a value to a single motor. Positions are always in degrees. Can write 'Goal_Position', 'Kp', or 'Kd'. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if data_name in ("Kp", "Kd"): self._gains[motor][data_name.lower()] = float(value) @@ -656,6 +674,10 @@ class DamiaoMotorsBus(MotorsBusBase): def _batch_refresh(self, motors: list[str]) -> None: """Internal helper to refresh a list of motors and update cache.""" + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Send refresh commands for motor in motors: motor_id = self._get_motor_id(motor) @@ -678,10 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + @check_if_not_connected + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. """ + if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): @@ -690,6 +714,8 @@ class DamiaoMotorsBus(MotorsBusBase): elif data_name == "Goal_Position": # Step 1: Send all MIT control commands recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") for motor, value_degrees in values.items(): motor_id = self._get_motor_id(motor) motor_name = self._get_motor_name(motor) @@ -732,9 +758,9 @@ class DamiaoMotorsBus(MotorsBusBase): def record_ranges_of_motion( self, - motors: NameOrID | list[NameOrID] | None = None, + motors: str | list[str] | None = None, display_values: bool = True, - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + ) -> tuple[dict[str, Value], dict[str, Value]]: """ Interactively record the min/max values of each motor in degrees. diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index c6752ee96..bca455dc5 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus): for motor, m in self.motors.items(): calibration[motor] = MotorCalibration( id=m.id, - drive_mode=drive_modes[motor], - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + drive_mode=int(drive_modes[motor]), + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus): if cache: self.calibration = calibration_dict - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) @@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) @@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus): On Dynamixel Motors: Present_Position = Actual_Position + Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None return {id_: data[0] for id_, data in data_list.items()} diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 7ce3388b6..58a65310d 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus): self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch - self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( + self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] self.port_handler, scs.PortHandler ) self.packet_handler = scs.PacketHandler(protocol_version) @@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus): calibration[motor] = MotorCalibration( id=m.id, drive_mode=0, - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus): On Feetech Motors: Present_Position = Actual_Position - Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus): return half_turn_homings - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) @@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Lock") self._write(addr, length, motor, 0, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Lock", motor, 1, num_retry=num_retry) @@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus): def _broadcast_ping(self) -> tuple[dict[int, int], int]: import scservo_sdk as scs - data_list = {} + data_list: dict[int, int] = {} status_length = 6 @@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus): if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} if ids_errors: diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index c04f718b6..bc3ffb7e2 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -23,6 +23,7 @@ from __future__ import annotations import abc import logging +from collections.abc import Sequence from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC): pass @abc.abstractmethod - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """Write values to multiple motors.""" pass @@ -179,15 +180,16 @@ class Motor: class PortHandler(Protocol): - def __init__(self, port_name): - self.is_open: bool - self.baudrate: int - self.packet_start_time: float - self.packet_timeout: float - self.tx_time_per_byte: float - self.is_using: bool - self.port_name: str - self.ser: serial.Serial + is_open: bool + baudrate: int + packet_start_time: float + packet_timeout: float + tx_time_per_byte: float + is_using: bool + port_name: str + ser: serial.Serial + + def __init__(self, port_name: str) -> None: ... def openPort(self): ... def closePort(self): ... @@ -240,19 +242,22 @@ class PacketHandler(Protocol): def regWriteTxRx(self, port, id, address, length, data): ... def syncReadTx(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... + def broadcastPing(self, port): ... class GroupSyncRead(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.last_result: bool - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + last_result: bool + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id): ... def removeParam(self, id): ... @@ -265,15 +270,17 @@ class GroupSyncRead(Protocol): class GroupSyncWrite(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id, data): ... def removeParam(self, id): ... @@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motor_model(self, motor: NameOrID) -> int: + def _get_motor_model(self, motor: NameOrID) -> str: if isinstance(motor, str): return self.motors[motor].model elif isinstance(motor, int): @@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]: if motors is None: return list(self.motors) elif isinstance(motors, str): return [motors] - elif isinstance(motors, list): - return motors.copy() + elif isinstance(motors, int): + return [self._id_to_name(motors)] + elif isinstance(motors, Sequence): + return [m if isinstance(m, str) else self._id_to_name(m) for m in motors] else: raise TypeError(motors) - def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]: if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): @@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase): pass @abc.abstractmethod - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: """Enable torque on selected motors. Args: - motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. + motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`. + Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. """ pass @contextmanager - def torque_disabled(self, motors: int | str | list[str] | None = None): + def torque_disabled(self, motors: str | list[str] | None = None): """Context-manager that guarantees torque is re-enabled. This helper is useful to temporarily disable torque when configuring motors. @@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase): """ pass - def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: + def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None: """Restore factory calibration for the selected motors. Homing offset is set to ``0`` and min/max position limits are set to the full usable range. The in-memory :pyattr:`calibration` is cleared. Args: - motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) + motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default) resets every motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - for motor in motors: + for motor in motor_names: model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 self.write("Homing_Offset", motor, 0, normalize=False) @@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase): self.calibration = {} - def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: + def set_half_turn_homings( + self, motors: NameOrID | Sequence[NameOrID] | None = None + ) -> dict[NameOrID, Value]: """Centre each motor range around its current position. The function computes and writes a homing offset such that the present position becomes exactly one @@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase): motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`). Returns: - dict[NameOrID, Value]: Mapping *motor → written homing offset*. + dict[str, Value]: Mapping *motor name → written homing offset*. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - self.reset_calibration(motors) - actual_positions = self.sync_read("Present_Position", motors, normalize=False) + self.reset_calibration(motor_names) + actual_positions = self.sync_read("Present_Position", motor_names, normalize=False) homing_offsets = self._get_half_turn_homings(actual_positions) for motor, offset in homing_offsets.items(): self.write("Homing_Offset", motor, offset) @@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase): pass def record_ranges_of_motion( - self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: """Interactively record the min/max encoder values of each motor. Move the joints by hand (with torque disabled) while the method streams live positions. Press @@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase): display_values (bool, optional): When `True` (default) a live table is printed to the console. Returns: - tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the + tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the extreme values observed for each motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - start_positions = self.sync_read("Present_Position", motors, normalize=False) + start_positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = start_positions.copy() maxes = start_positions.copy() user_pressed_enter = False while not user_pressed_enter: - positions = self.sync_read("Present_Position", motors, normalize=False) + positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} if display_values: print("\n-------------------------------------------") print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") - for motor in motors: + for motor in motor_names: print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") if enter_pressed(): @@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase): if display_values and not user_pressed_enter: # Move cursor up to overwrite the previous output - move_cursor_up(len(motors) + 3) + move_cursor_up(len(motor_names) + 3) - same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] + same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]] if same_min_max: raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") @@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) else: - return + return None if self._is_error(error): if raise_on_error: raise RuntimeError(self.packet_handler.getRxPacketError(error)) else: - return + return None return model_number @@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase): err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - id_value = self._decode_sign(data_name, {id_: value}) + decoded = self._decode_sign(data_name, {id_: value}) if normalize and data_name in self.normalized_data: - id_value = self._normalize(id_value) + normalized = self._normalize(decoded) + return normalized[id_] - return id_value[id_] + return decoded[id_] def _read( self, @@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase): num_retry: int = 0, raise_on_error: bool = True, err_msg: str = "", - ) -> tuple[int, int]: + ) -> tuple[int, int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase): model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) + int_value = int(value) if normalize and data_name in self.normalized_data: - value = self._unnormalize({id_: value})[id_] + int_value = self._unnormalize({id_: value})[id_] - value = self._encode_sign(data_name, {id_: value})[id_] + int_value = self._encode_sign(data_name, {id_: int_value})[id_] - err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries." + self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( self, @@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase): def sync_read( self, data_name: str, - motors: str | list[str] | None = None, + motors: NameOrID | Sequence[NameOrID] | None = None, *, normalize: bool = True, num_retry: int = 0, @@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase): Args: data_name (str): Register name. - motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. + motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor. normalize (bool, optional): Normalisation flag. Defaults to `True`. num_retry (int, optional): Retry attempts. Defaults to `0`. @@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase): addr, length = get_address(self.model_ctrl_table, model, data_name) err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - ids_values, _ = self._sync_read( + raw_ids_values, _ = self._sync_read( addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg ) - ids_values = self._decode_sign(data_name, ids_values) + decoded = self._decode_sign(data_name, raw_ids_values) if normalize and data_name in self.normalized_data: - ids_values = self._normalize(ids_values) + normalized = self._normalize(decoded) + return {self._id_to_name(id_): value for id_, value in normalized.items()} - return {self._id_to_name(id_): value for id_, value in ids_values.items()} + return {self._id_to_name(id_): value for id_, value in decoded.items()} def _sync_read( self, @@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase): num_retry (int, optional): Retry attempts. Defaults to `0`. """ - ids_values = self._get_ids_values_dict(values) - models = [self._id_to_model(id_) for id_ in ids_values] + raw_ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in raw_ids_values] if self._has_different_ctrl_tables: assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) + int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()} if normalize and data_name in self.normalized_data: - ids_values = self._unnormalize(ids_values) + int_ids_values = self._unnormalize(raw_ids_values) - ids_values = self._encode_sign(data_name, ids_values) + int_ids_values = self._encode_sign(data_name, int_ids_values) - err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries." + self._sync_write( + addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) def _sync_write( self, diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index 8d42bfdb7..a77e066cf 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep): # copy over non-STATE features for ft, feats in features.items(): - if ft != PipelineFeatureType.STATE: + if ft != FeatureType.STATE: new_features[ft] = feats.copy() # rebuild STATE features @@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep): # add our new flattened state state_feats[OBS_STATE] = PolicyFeature( - key=OBS_STATE, + type=FeatureType.STATE, shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] - dtype="float32", - description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), ) - new_features[PipelineFeatureType.STATE] = state_feats + new_features[FeatureType.STATE] = state_feats return new_features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 97ec716ff..8de376928 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): Args: save_directory: The directory where the pipeline will be saved. If None, saves to HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. - repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`. push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. card_kwargs: Additional arguments passed to the card template to customize the card. config_filename: The name of the JSON configuration file. If None, a name is diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 466eb07e5..2e3885e67 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_openarm_follower import BiOpenArmFollowerConfig @@ -112,6 +113,7 @@ class BiOpenArmFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -133,6 +135,7 @@ class BiOpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -146,6 +149,7 @@ class BiOpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -170,6 +174,7 @@ class BiOpenArmFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 09f849772..28c58b898 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_so_follower import BiSOFollowerConfig @@ -96,6 +97,7 @@ class BiSOFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -116,6 +118,7 @@ class BiSOFollower(Robot): self.left_arm.setup_motors() self.right_arm.setup_motors() + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -129,6 +132,7 @@ class BiSOFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: # Remove "left_" prefix left_action = { @@ -148,6 +152,7 @@ class BiSOFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c221afd10..d6794a226 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -119,6 +119,7 @@ class OpenArmFollower(Robot): """Check if robot is connected.""" return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the robot and optionally calibrate. @@ -126,8 +127,6 @@ class OpenArmFollower(Robot): We assume that at connection time, the arms are in a safe rest position, and torque can be safely disabled to run calibration if needed. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -219,6 +218,7 @@ class OpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_observation(self) -> RobotObservation: """ Get current observation from robot including position, velocity, and torque. @@ -228,9 +228,6 @@ class OpenArmFollower(Robot): """ start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - obs_dict: dict[str, Any] = {} states = self.bus.sync_read_all_states() @@ -253,6 +250,7 @@ class OpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -272,8 +270,6 @@ class OpenArmFollower(Robot): Returns: The action actually sent (potentially clipped) """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -333,10 +329,9 @@ class OpenArmFollower(Robot): return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): """Disconnect from robot.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus self.bus.disconnect(self.config.disable_torque_on_disconnect) diff --git a/src/lerobot/robots/so_follower/config_so_follower.py b/src/lerobot/robots/so_follower/config_so_follower.py index e9ce27123..1ee589bda 100644 --- a/src/lerobot/robots/so_follower/config_so_follower.py +++ b/src/lerobot/robots/so_follower/config_so_follower.py @@ -40,7 +40,7 @@ class SOFollowerConfig: cameras: dict[str, CameraConfig] = field(default_factory=dict) # Set to `True` for backward compatibility with previous policies/dataset - use_degrees: bool = False + use_degrees: bool = True @RobotConfig.register_subclass("so101_follower") diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 2ca9c520d..7c222ac6c 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -109,11 +109,14 @@ Using JSON config file: --config_path path/to/edit_config.json """ +import abc import logging import shutil from dataclasses import dataclass from pathlib import Path +import draccus + from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, @@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging @dataclass -class DeleteEpisodesConfig: - type: str = "delete_episodes" +class OperationConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@OperationConfig.register_subclass("delete_episodes") +@dataclass +class DeleteEpisodesConfig(OperationConfig): episode_indices: list[int] | None = None +@OperationConfig.register_subclass("split") @dataclass -class SplitConfig: - type: str = "split" +class SplitConfig(OperationConfig): splits: dict[str, float | list[int]] | None = None +@OperationConfig.register_subclass("merge") @dataclass -class MergeConfig: - type: str = "merge" +class MergeConfig(OperationConfig): repo_ids: list[str] | None = None +@OperationConfig.register_subclass("remove_feature") @dataclass -class RemoveFeatureConfig: - type: str = "remove_feature" +class RemoveFeatureConfig(OperationConfig): feature_names: list[str] | None = None +@OperationConfig.register_subclass("modify_tasks") @dataclass -class ModifyTasksConfig: - type: str = "modify_tasks" +class ModifyTasksConfig(OperationConfig): new_task: str | None = None episode_tasks: dict[str, str] | None = None +@OperationConfig.register_subclass("convert_image_to_video") @dataclass -class ConvertImageToVideoConfig: - type: str = "convert_image_to_video" +class ConvertImageToVideoConfig(OperationConfig): output_dir: str | None = None vcodec: str = "libsvtav1" pix_fmt: str = "yuv420p" @@ -177,14 +187,7 @@ class ConvertImageToVideoConfig: @dataclass class EditDatasetConfig: repo_id: str - operation: ( - DeleteEpisodesConfig - | SplitConfig - | MergeConfig - | RemoveFeatureConfig - | ModifyTasksConfig - | ConvertImageToVideoConfig - ) + operation: OperationConfig root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: - raise ValueError( - f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" - ) + available = ", ".join(OperationConfig.get_known_choices()) + raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") def main() -> None: diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index c4383293f..74b0c9b83 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader from ..teleoperator import Teleoperator @@ -88,6 +89,7 @@ class BiOpenArmLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -109,6 +111,7 @@ class BiOpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: action_dict = {} @@ -126,6 +129,7 @@ class BiOpenArmLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 90bf2a92d..e84ac6f50 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,7 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig -from lerobot.utils.decorators import check_if_not_connected +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index edf4d7090..d9eaabe0f 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -21,7 +21,7 @@ from typing import Any from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_openarm_leader import OpenArmLeaderConfig @@ -84,6 +84,7 @@ class OpenArmLeader(Teleoperator): """Check if teleoperator is connected.""" return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the teleoperator. @@ -91,8 +92,6 @@ class OpenArmLeader(Teleoperator): For manual control, we disable torque after connecting so the arm can be moved by hand. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -183,6 +182,7 @@ class OpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: """ Get current action from the leader arm. @@ -193,8 +193,6 @@ class OpenArmLeader(Teleoperator): Reads all motor states (pos/vel/torque) in one CAN refresh cycle. """ start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") action_dict: dict[str, Any] = {} @@ -214,10 +212,9 @@ class OpenArmLeader(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.") + @check_if_not_connected def disconnect(self) -> None: """Disconnect from teleoperator.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus # For manual control, ensure torque is disabled before disconnecting diff --git a/src/lerobot/teleoperators/so_leader/config_so_leader.py b/src/lerobot/teleoperators/so_leader/config_so_leader.py index dd55196d7..2b4f782a7 100644 --- a/src/lerobot/teleoperators/so_leader/config_so_leader.py +++ b/src/lerobot/teleoperators/so_leader/config_so_leader.py @@ -28,7 +28,7 @@ class SOLeaderConfig: port: str # Whether to use degrees for angles - use_degrees: bool = False + use_degrees: bool = True @TeleoperatorConfig.register_subclass("so101_leader") diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py new file mode 100644 index 000000000..bf7386b52 --- /dev/null +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import draccus +import pytest + +from lerobot.scripts.lerobot_edit_dataset import ( + ConvertImageToVideoConfig, + DeleteEpisodesConfig, + EditDatasetConfig, + MergeConfig, + ModifyTasksConfig, + OperationConfig, + RemoveFeatureConfig, + SplitConfig, +) + + +def parse_cfg(cli_args: list[str]) -> EditDatasetConfig: + """Helper to parse CLI args into an EditDatasetConfig via draccus.""" + return draccus.parse(EditDatasetConfig, args=cli_args) + + +class TestOperationTypeParsing: + """Test that --operation.type correctly selects the right config subclass.""" + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + ("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_operation_type_resolves_correct_class(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + assert isinstance(cfg.operation, expected_cls), ( + f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}" + ) + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + ("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_get_choice_name_roundtrips(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) + assert resolved_name == type_name