Merge branch 'main' into feat/robotwin-benchmark

Pepijn
2026-04-16 18:57:39 +02:00
committed by GitHub
42 changed files with 1581 additions and 423 deletions
+4 -22
@@ -2,11 +2,6 @@
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
## Type / Scope
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
- **Scope**: (optional — name of module or package affected)
## Summary / Motivation
- One-paragraph description of what changes and why.
@@ -19,28 +14,14 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S
## What changed
- Short, concrete bullets of the modifications (files/behaviour).
- Short, concrete bullets explaining the functional changes (how the behavior or output differs now).
- Short note if this introduces breaking changes and migration steps.
## How was this tested (or how to run locally)
- Tests added: list new tests or test files.
- Tests added: list new tests or test files. `pytest -q tests/ -k <keyword>`
- Manual checks / dataset runs performed.
- Instructions for the reviewer
Example:
- Ran the relevant tests:
```bash
pytest -q tests/ -k <keyword>
```
- Reproduce with a quick example or CLI (if applicable):
```bash
lerobot-train --some.option=true
```
- Instructions for the reviewer for reproducing with a quick example or CLI (if applicable)
## Checklist (required before merge)
@@ -48,6 +29,7 @@ Example:
- [ ] All tests pass locally (`pytest`)
- [ ] Documentation updated
- [ ] CI is green
- [ ] Community Review: I have reviewed another contributor's open PR and linked it here: # (insert PR number/link)
## Reviewer notes
@@ -33,7 +33,7 @@ jobs:
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
with:
package_name: lerobot
secrets:
+4 -1
@@ -78,6 +78,9 @@ Use the templates for required fields and examples.
- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).
One member of the LeRobot team will then review your contribution.
> [!IMPORTANT]
> Community Review Policy: To help scale our efforts and foster a collaborative environment, we ask contributors to review at least one other person's open PR before their own receives attention. This shared responsibility multiplies our review capacity and helps everyone's code get merged faster!
Once you have submitted your PR and completed a peer review, a member of the LeRobot team will review your contribution.
Thank you for contributing to LeRobot!
+6
@@ -32,6 +32,12 @@ Once you've gathered enough trajectories, you'll train a neural network to i
If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support.
<Tip>
Want to quickly get the right commands for your setup? The [quickstart notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) lets you configure your robot once and generates all the commands below ready to paste.
</Tip>
## Set up and Calibrate
If you haven't yet set up and calibrated your robot and teleop device, please do so by following the robot-specific tutorial.
+342
@@ -0,0 +1,342 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🤗 LeRobot Quickstart\n",
"\n",
"Calibration → teleoperation → data collection → training → evaluation.\n",
"\n",
"Install the required dependencies: `pip install -e .[notebook,dataset,training,viz,hardware]`.\n",
"\n",
"**How to use:**\n",
"1. Edit the **Configuration** cell with your settings.\n",
"2. Run all cells (`Run All`).\n",
"3. Each section prints a ready-to-paste terminal command - copy it and run it.\n",
"\n",
"Each setup is different, so please refer to the [LeRobot documentation](https://huggingface.co/docs/lerobot/il_robots) for more details on each step and the available options. <br>\n",
"Feel free to make this notebook your own and adapt it to your needs!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _cameras_arg(cameras: dict) -> str:\n",
" if not cameras:\n",
" return \"\"\n",
" entries = [f\"{n}: {{{', '.join(f'{k}: {v}' for k, v in cfg.items())}}}\" for n, cfg in cameras.items()]\n",
" return \"{ \" + \", \".join(entries) + \" }\"\n",
"\n",
"\n",
"def print_cmd(*parts: str) -> None:\n",
" \"\"\"Print a shell command with line continuations, skipping empty parts.\"\"\"\n",
" non_empty = [p for p in parts if p]\n",
" print(\" \\\\\\n \".join(non_empty))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Configuration\n",
"\n",
"Edit this cell, then **Run All** to generate all commands below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Robot (follower) - run `lerobot-find-port` to discover the port\n",
"ROBOT_TYPE = \"so101_follower\"\n",
"ROBOT_PORT = \"/dev/ttyACM0\"\n",
"ROBOT_ID = \"my_follower_arm\"\n",
"\n",
"# Teleop (leader) - run `lerobot-find-port` to discover the port\n",
"TELEOP_TYPE = \"so101_leader\"\n",
"TELEOP_PORT = \"/dev/ttyACM1\"\n",
"TELEOP_ID = \"my_leader_arm\"\n",
"\n",
"# Cameras - set to {} to disable\n",
"# Run `lerobot-find-cameras opencv` to list available cameras and their indices\n",
"CAMERAS = {\n",
" \"top\": {\"type\": \"opencv\", \"index_or_path\": 2, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
" \"wrist\": {\"type\": \"opencv\", \"index_or_path\": 4, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
"}\n",
"\n",
"# Dataset\n",
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
"DATASET_NAME = \"my_so101_dataset\"\n",
"TASK_DESCRIPTION = \"pick and place the block\"\n",
"NUM_EPISODES = 10\n",
"\n",
"# Training\n",
"POLICY_TYPE = \"act\" # act, diffusion, smolvla, ...\n",
"POLICY_DEVICE = \"cuda\" # cuda / cpu / mps\n",
"TRAIN_STEPS = 10_000\n",
"SAVE_FREQ = 2_000\n",
"OUTPUT_DIR = f\"outputs/train/{DATASET_NAME}\"\n",
"\n",
"# Inference - Hub repo ID or local checkpoint path\n",
"# e.g. set to f\"{OUTPUT_DIR}/checkpoints/last\" to use a local checkpoint\n",
"POLICY_PATH = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
"LAST_CHECKPOINT_PATH = f\"{OUTPUT_DIR}/checkpoints/last\"\n",
"\n",
"# Derived\n",
"DATASET_REPO_ID = f\"{HF_USER}/{DATASET_NAME}\"\n",
"DATASET_ROOT = f\"data/{DATASET_NAME}\"\n",
"POLICY_REPO_ID = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
"EVAL_REPO_ID = f\"{HF_USER}/eval_{DATASET_NAME}\"\n",
"CAMERAS_ARG = _cameras_arg(CAMERAS)\n",
"CAMERAS_FLAG = f'--robot.cameras=\"{CAMERAS_ARG}\"' if CAMERAS_ARG else \"\"\n",
"\n",
"print(f\"Robot : {ROBOT_TYPE} @ {ROBOT_PORT}\")\n",
"print(f\"Teleop : {TELEOP_TYPE} @ {TELEOP_PORT}\")\n",
"print(f\"Cameras: {list(CAMERAS) or 'none'}\")\n",
"print(f\"Dataset: {DATASET_REPO_ID} ({NUM_EPISODES} episodes) saved to {DATASET_ROOT}\")\n",
"print(f\"Policy : {POLICY_TYPE} -> {POLICY_REPO_ID}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 1. Calibration\n",
"\n",
"Run once per arm before first use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Follower\n",
"print_cmd(\n",
" \"lerobot-calibrate\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Leader\n",
"print_cmd(\n",
" \"lerobot-calibrate\",\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 2. Teleoperation\n",
"\n",
"See the [teleoperation docs](https://huggingface.co/docs/lerobot/il_robots#teleoperate) and the [cameras guide](https://huggingface.co/docs/lerobot/cameras) for more options."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-teleoperate\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" \"--display_data=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 3. Record Dataset\n",
"\n",
"See the [recording docs](https://huggingface.co/docs/lerobot/il_robots#record-a-dataset) for tips on gathering good data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-record\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
" \"--dataset.streaming_encoding=true\",\n",
" \"--display_data=true\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Resume a previously interrupted recording session\n",
"print_cmd(\n",
" \"lerobot-record\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
" f\"--dataset.root={DATASET_ROOT}\",\n",
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
" \"--dataset.streaming_encoding=true\",\n",
" \"--display_data=true\",\n",
" \"--resume=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 4. Train Policy\n",
"\n",
"See the [training docs](https://huggingface.co/docs/lerobot/il_robots#train-a-policy) for configuration options and tips."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-train\",\n",
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
" f\"--policy.type={POLICY_TYPE}\",\n",
" f\"--policy.device={POLICY_DEVICE}\",\n",
" f\"--policy.repo_id={POLICY_REPO_ID}\",\n",
" f\"--output_dir={OUTPUT_DIR}\",\n",
" f\"--steps={TRAIN_STEPS}\",\n",
" f\"--save_freq={SAVE_FREQ}\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Resume a previously interrupted training session\n",
"print_cmd(\n",
" \"lerobot-train\",\n",
" f\"--config_path={LAST_CHECKPOINT_PATH}/pretrained_model/train_config.json\",\n",
" \"--resume=true\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## 5. Inference\n",
"\n",
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also set it to `LAST_CHECKPOINT_PATH` to run a local checkpoint.\n",
"\n",
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_cmd(\n",
" \"lerobot-record\",\n",
" f\"--policy.path={POLICY_PATH}\",\n",
" f\"--robot.type={ROBOT_TYPE}\",\n",
" f\"--robot.port={ROBOT_PORT}\",\n",
" f\"--robot.id={ROBOT_ID}\",\n",
" CAMERAS_FLAG,\n",
" f\"--teleop.type={TELEOP_TYPE}\",\n",
" f\"--teleop.port={TELEOP_PORT}\",\n",
" f\"--teleop.id={TELEOP_ID}\",\n",
" f\"--dataset.repo_id={EVAL_REPO_ID}\",\n",
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
" \"--dataset.streaming_encoding=true\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot (3.12.3)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
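The notebook's two helper cells can be tried outside Jupyter. The sketch below restates `_cameras_arg` (as `cameras_arg`) and adds `format_cmd`, a hypothetical string-returning variant of `print_cmd` that is not part of the notebook, and runs both on a trimmed-down camera config to show the shape of the generated command. The four-space continuation indent is an assumption about the original cell.

```python
def cameras_arg(cameras: dict) -> str:
    """Render a camera config dict in the inline form passed to --robot.cameras."""
    if not cameras:
        return ""
    entries = []
    for name, cfg in cameras.items():
        fields = ", ".join(f"{key}: {value}" for key, value in cfg.items())
        entries.append(f"{name}: {{{fields}}}")
    return "{ " + ", ".join(entries) + " }"


def format_cmd(*parts: str) -> str:
    """Join non-empty parts with shell line continuations (hypothetical helper)."""
    return " \\\n    ".join(p for p in parts if p)


cameras = {"top": {"type": "opencv", "index_or_path": 2}}
flag = f'--robot.cameras="{cameras_arg(cameras)}"'

# The trailing empty string mimics an empty CAMERAS_FLAG and is skipped.
print(format_cmd("lerobot-teleoperate", "--robot.type=so101_follower", flag, ""))
```

Because empty parts are dropped, setting `CAMERAS = {}` in the Configuration cell removes the camera flag from every generated command without leaving a dangling continuation line.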
+14 -8
@@ -108,9 +108,9 @@ training = [
"wandb>=0.24.0,<0.25.0",
]
hardware = [
"pynput>=1.7.8,<1.9.0",
"pyserial>=3.5,<4.0",
"deepdiff>=7.0.1,<9.0.0",
"lerobot[pynput-dep]",
"lerobot[pyserial-dep]",
"lerobot[deepdiff-dep]",
]
viz = [
"rerun-sdk>=0.24.0,<0.27.0",
@@ -136,10 +136,14 @@ scipy-dep = ["scipy>=1.14.0,<2.0.0"]
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
pyserial-dep = ["pyserial>=3.5,<4.0"]
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
damiao = ["lerobot[can-dep]"]
robstride = ["lerobot[can-dep]"]
@@ -147,10 +151,11 @@ robstride = ["lerobot[can-dep]"]
openarms = ["lerobot[damiao]"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
lekiwi = ["lerobot[feetech]", "lerobot[pyzmq-dep]"]
unitree_g1 = [
# "unitree-sdk2==1.0.1",
"pyzmq>=26.2.1,<28.0.0",
"lerobot[pyzmq-dep]",
"lerobot[pyserial-dep]",
"onnxruntime>=1.16.0,<2.0.0",
"onnx>=1.16.0,<2.0.0",
"meshcat>=0.3.0,<0.4.0",
@@ -196,7 +201,8 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
@@ -33,7 +33,7 @@ import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
if TYPE_CHECKING or _reachy2_sdk_available:
from reachy2_sdk.media.camera import CameraView
@@ -76,6 +76,7 @@ class Reachy2Camera(Camera):
Args:
config: The configuration settings for the camera.
"""
require_package("reachy2_sdk", extra="reachy2")
super().__init__(config)
self.config = config
@@ -19,16 +19,18 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
import logging
import time
from threading import Event, Lock, Thread
from typing import Any
from typing import TYPE_CHECKING, Any
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
try:
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.import_utils import _pyrealsense2_available, require_package
if TYPE_CHECKING or _pyrealsense2_available:
import pyrealsense2 as rs
else:
rs = None
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
@@ -112,7 +114,7 @@ class RealSenseCamera(Camera):
Args:
config: The configuration settings for the camera.
"""
require_package("pyrealsense2", extra="intelrealsense")
super().__init__(config)
self.config = config
+11 -9
@@ -28,12 +28,19 @@ import json
import logging
import time
from threading import Event, Lock, Thread
from typing import Any
from typing import TYPE_CHECKING, Any
import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.import_utils import _zmq_available, require_package
if TYPE_CHECKING or _zmq_available:
import zmq
else:
zmq = None
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
@@ -74,8 +81,8 @@ class ZMQCamera(Camera):
"""
def __init__(self, config: ZMQCameraConfig):
require_package("pyzmq", extra="pyzmq-dep", import_name="zmq")
super().__init__(config)
import zmq
self.config = config
self.server_address = config.server_address
@@ -117,8 +124,6 @@ class ZMQCamera(Camera):
logger.info(f"Connecting to {self}...")
try:
import zmq
self.context = zmq.Context()
self.socket = self.context.socket(zmq.SUB)
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
@@ -180,11 +185,8 @@ class ZMQCamera(Camera):
try:
message = self.socket.recv_string()
except Exception as e:
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
except zmq.Again as e:
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
# Decode JSON message
data = json.loads(message)
+7 -4
@@ -28,6 +28,12 @@ import numpy as np
import torch
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
from lerobot.utils.import_utils import _deepdiff_available, require_package
if TYPE_CHECKING or _deepdiff_available:
from deepdiff import DeepDiff
else:
DeepDiff = None
if TYPE_CHECKING:
from lerobot.datasets import LeRobotDataset
@@ -217,10 +223,7 @@ def sanity_check_dataset_robot_compatibility(
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
from lerobot.utils.import_utils import require_package
require_package("deepdiff", extra="hardware")
from deepdiff import DeepDiff
require_package("deepdiff", extra="deepdiff-dep")
from lerobot.utils.constants import DEFAULT_FEATURES
+2 -2
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
except BaseException:
dataset = kwargs.get("dataset")
writer = getattr(dataset, "writer", None) if dataset else None
if writer is not None and writer.image_writer is not None:
logger.warning("Waiting for image writer to terminate...")
writer.image_writer.stop()
raise e
raise
return wrapper
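Switching from `except Exception` to `except BaseException` matters because `KeyboardInterrupt` and `SystemExit` derive from `BaseException` directly, so a Ctrl+C during recording would otherwise skip the writer cleanup. A minimal sketch of the difference follows; `FakeWriter`, `stop_writer_on`, and `run` are hypothetical names used only for illustration, not the project's code.

```python
import functools


class FakeWriter:
    """Hypothetical stand-in for an image writer; only records stop() calls."""

    def __init__(self) -> None:
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


def stop_writer_on(exc_type):
    """Build a decorator that stops the `writer` kwarg when `exc_type` is raised."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exc_type:
                writer = kwargs.get("writer")
                if writer is not None:
                    writer.stop()
                raise  # re-raise the active exception unchanged

        return wrapper

    return decorator


def interrupted_recording(writer=None):
    raise KeyboardInterrupt  # simulates Ctrl+C in the middle of a recording


def run(exc_type) -> bool:
    """Return whether cleanup ran when the guarded call was interrupted."""
    writer = FakeWriter()
    guarded = stop_writer_on(exc_type)(interrupted_recording)
    try:
        guarded(writer=writer)
    except KeyboardInterrupt:
        pass
    return writer.stopped


print(run(Exception))      # False: KeyboardInterrupt bypasses `except Exception`
print(run(BaseException))  # True: cleanup runs before the interrupt propagates
```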
+12 -7
@@ -12,8 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import _placo_available, require_package
if TYPE_CHECKING or _placo_available:
import placo # type: ignore[import-not-found]
else:
placo = None
class RobotKinematics:
"""Robot kinematics using placo library for forward and inverse kinematics."""
@@ -32,13 +43,7 @@ class RobotKinematics:
target_frame_name (str): Name of the end-effector frame in the URDF
joint_names (list[str] | None): List of joint names to use for the kinematics solver
"""
try:
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
except ImportError as e:
raise ImportError(
"placo is required for RobotKinematics. "
"Please install the optional dependencies of `kinematics` in the package."
) from e
require_package("placo", extra="placo-dep")
self.robot = placo.RobotWrapper(urdf_path)
self.solver = placo.KinematicsSolver(self.robot)
+2 -1
@@ -24,7 +24,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
from lerobot.utils.import_utils import _can_available, require_package
if TYPE_CHECKING or _can_available:
import can
@@ -111,6 +111,7 @@ class DamiaoMotorsBus(MotorsBusBase):
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
require_package("python-can", extra="damiao", import_name="can")
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
+2 -2
@@ -356,8 +356,8 @@ class SerialMotorsBus(MotorsBusBase):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
require_package("pyserial", extra="hardware", import_name="serial")
require_package("deepdiff", extra="hardware")
require_package("pyserial", extra="pyserial-dep", import_name="serial")
require_package("deepdiff", extra="deepdiff-dep")
super().__init__(port, motors, calibration)
self.port_handler: PortHandler
+3 -2
@@ -23,12 +23,12 @@ from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
from lerobot.utils.import_utils import _can_available, require_package
if TYPE_CHECKING or _can_available:
import can
else:
can = SimpleNamespace(Message=object, interface=None)
can = SimpleNamespace(Message=object, interface=None, BusABC=object)
import numpy as np
from lerobot.utils.errors import DeviceNotConnectedError
@@ -106,6 +106,7 @@ class RobstrideMotorsBus(MotorsBusBase):
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
"""
require_package("python-can", extra="robstride", import_name="can")
super().__init__(port, motors, calibration)
self.port = port
self.can_interface = can_interface
+7 -3
@@ -18,14 +18,21 @@ import logging
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.utils.constants import SCHEDULER_STATE
from lerobot.utils.import_utils import _diffusers_available, require_package
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
if TYPE_CHECKING or _diffusers_available:
from diffusers.optimization import get_scheduler
else:
get_scheduler = None
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
@@ -47,10 +54,7 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int | None = None
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="diffusion")
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
return get_scheduler(**kwargs)
@@ -23,6 +23,7 @@ TODO(alexander-soare):
import math
from collections import deque
from collections.abc import Callable
from typing import TYPE_CHECKING
import einops
import numpy as np
@@ -32,6 +33,14 @@ import torchvision
from torch import Tensor, nn
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
else:
DDIMScheduler = None
DDPMScheduler = None
from ..pretrained import PreTrainedPolicy
from ..utils import (
@@ -64,6 +73,7 @@ class DiffusionPolicy(PreTrainedPolicy):
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
require_package("diffusers", extra="diffusion")
super().__init__(config)
config.validate_features()
self.config = config
@@ -155,11 +165,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict):
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="diffusion")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
if name == "DDPM":
return DDPMScheduler(**kwargs)
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
self.model.eval()
def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
if self._beta_dist is None:
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=None,
**output_kwargs["images_kwargs"],
)
if isinstance(image_inputs["pixel_values"], list):
_pv = image_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
image_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
if isinstance(video_inputs["pixel_values"], list):
_pv = video_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
video_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
text = replace_in_text(text)
if len(unified_frame_list) > 0:
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
def _to_tensor(v):
if isinstance(v, torch.Tensor):
return v
if isinstance(v, list):
if v and isinstance(v[0], list):
v = [t for sub in v for t in sub]
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
return torch.as_tensor(v)
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
+1
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
@@ -43,6 +43,7 @@ from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from .configuration_groot import GrootConfig
@@ -59,6 +60,7 @@ class GrootPolicy(PreTrainedPolicy):
def __init__(self, config: GrootConfig, **kwargs):
"""Initialize Groot policy wrapper."""
require_package("transformers", extra="groot")
super().__init__(config)
config.validate_features()
self.config = config
@@ -36,7 +36,7 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor
from lerobot.utils.import_utils import _transformers_available
from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package
from .configuration_multi_task_dit import MultiTaskDiTConfig
@@ -46,6 +46,13 @@ if TYPE_CHECKING or _transformers_available:
else:
CLIPTextModel = None
CLIPVisionModel = None
if TYPE_CHECKING or _diffusers_available:
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
else:
DDIMScheduler = None
DDPMScheduler = None
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
@@ -65,6 +72,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
name = "multi_task_dit"
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
require_package("transformers", extra="multi_task_dit")
require_package("diffusers", extra="multi_task_dit")
super().__init__(config)
config.validate_features()
self.config = config
@@ -643,12 +652,6 @@ class DiffusionObjective(nn.Module):
"prediction_type": config.prediction_type,
}
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="multi_task_dit")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
+2 -1
@@ -26,7 +26,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _transformers_available
from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
@@ -947,6 +947,7 @@ class PI0Policy(PreTrainedPolicy):
Args:
config: Policy configuration class instance.
"""
require_package("transformers", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
+2 -1
@@ -26,7 +26,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _transformers_available
from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
@@ -918,6 +918,7 @@ class PI05Policy(PreTrainedPolicy):
Args:
config: Policy configuration class instance.
"""
require_package("transformers", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
@@ -26,7 +26,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.utils.import_utils import _scipy_available, _transformers_available
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _scipy_available:
@@ -35,7 +35,7 @@ else:
idct = None
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from ..pi_gemma import (
@@ -44,6 +44,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
AutoProcessor = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
@@ -826,14 +827,14 @@ class PI0FastPolicy(PreTrainedPolicy):
Args:
config: Policy configuration class instance.
"""
require_package("transformers", extra="pi")
require_package("scipy", extra="pi")
super().__init__(config)
config.validate_features()
self.config = config
# Load tokenizers first
try:
from transformers import AutoProcessor, AutoTokenizer
# Load FAST tokenizer
self.action_tokenizer = AutoProcessor.from_pretrained(
config.action_tokenizer_name, trust_remote_code=True
@@ -62,6 +62,7 @@ from torch import Tensor, nn
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.device_utils import get_safe_dtype
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..rtc.modeling_rtc import RTCProcessor
@@ -239,6 +240,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
the configuration class is used.
"""
require_package("transformers", extra="smolvla")
super().__init__(config)
config.validate_features()
self.config = config
+54 -58
@@ -15,6 +15,7 @@
# limitations under the License.
import functools
import threading
from collections.abc import Callable, Sequence
from contextlib import suppress
from typing import TypedDict
@@ -115,6 +116,7 @@ class ReplayBuffer:
self.size = 0
self.initialized = False
self.optimize_memory = optimize_memory
self._lock = threading.Lock()
# Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
@@ -198,68 +200,75 @@ class ReplayBuffer:
complementary_info: dict[str, torch.Tensor] | None = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
with self._lock:
# Initialize storage if this is the first transition
if not self.initialized:
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
# Store the transition in pre-allocated tensors
for key in self.states:
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.dones[self.position] = done
self.truncateds[self.position] = truncated
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
# Store the complementary_info
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
for key in self.complementary_info_keys:
if key in complementary_info:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
self.position = (self.position + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized:
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
with self._lock:
batch_size = min(batch_size, self.size)
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
# Random indices for sampling - create on the same device as storage
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state
batch_state = {}
batch_next_state = {}
batch_state = {}
batch_next_state = {}
# First pass: load all state tensors to target device
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
for key in self.states:
batch_state[key] = self.states[key][idx].to(self.device)
if not self.optimize_memory:
# Standard approach - load next_states directly
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
# Memory-optimized approach - get next_state from the next index
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
if not self.optimize_memory:
batch_next_state[key] = self.next_states[key][idx].to(self.device)
else:
next_idx = (idx + 1) % self.capacity
batch_next_state[key] = self.states[key][next_idx].to(self.device)
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
# Apply image augmentation in a batched way if needed
if self.use_drq and image_keys:
# Concatenate all images from state and next_state
all_images = []
@@ -280,19 +289,6 @@ class ReplayBuffer:
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
if self.has_complementary_info:
batch_complementary_info = {}
for key in self.complementary_info_keys:
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
return BatchTransition(
state=batch_state,
action=batch_actions,
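Note on the hunks above: they add a `threading.Lock` to `ReplayBuffer` and take it in both `add` and `sample`, so that a writer thread cannot advance `position`/`size` while a reader is computing indices. The following is a minimal sketch of that idea only, not the actual `ReplayBuffer`: `TinyRingBuffer` is a hypothetical class that uses plain Python lists instead of pre-allocated tensors, and it assumes a single lock held for the whole of each method.

```python
import random
import threading


class TinyRingBuffer:
    """Hypothetical fixed-capacity ring buffer guarded by one lock."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.items = [None] * capacity
        self.position = 0  # next slot to overwrite
        self.size = 0      # number of valid slots
        self._lock = threading.Lock()

    def add(self, item) -> None:
        # Write the slot and advance the counters as one atomic step.
        with self._lock:
            self.items[self.position] = item
            self.position = (self.position + 1) % self.capacity
            self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int) -> list:
        # Read size and gather items under the same lock, so the
        # indices always refer to slots that have been fully written.
        with self._lock:
            if self.size == 0:
                raise RuntimeError("Cannot sample from an empty buffer.")
            batch_size = min(batch_size, self.size)
            idx = [random.randrange(self.size) for _ in range(batch_size)]
            return [self.items[i] for i in idx]


# One writer thread fills the buffer, then the main thread samples.
buf = TinyRingBuffer(capacity=8)
writer = threading.Thread(target=lambda: [buf.add(i) for i in range(100)])
writer.start()
writer.join()
batch = buf.sample(4)
```

Holding the lock only around index selection and the gather, as the diff appears to do by moving the DrQ augmentation outside the `with` block, keeps the critical section short; the sketch has no expensive post-processing, so it simply locks the whole method.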
+2 -2
@@ -551,8 +551,8 @@ def step_env_and_process_transition(
terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
new_info = processed_action_transition[TransitionKey.INFO].copy()
new_info.update(info)
new_info = info.copy()
new_info.update(processed_action_transition[TransitionKey.INFO])
new_transition = create_transition(
observation=obs,
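Note on the hunk above: it swaps which dictionary is copied and which is passed to `update`, which changes who wins when both contain the same key. The sketch below illustrates the precedence rule with plain dictionaries. It assumes the first `copy`/`update` pair shown in the hunk is the removed version and the second pair is the added one (the +/- markers are not visible here), and the keys `is_success`, `step`, and `intervention` are made up for illustration.

```python
# Hypothetical contents; the real dictionaries come from the environment
# step and from the processed action transition.
env_info = {"is_success": False, "step": 3}
processor_info = {"is_success": True, "intervention": True}

# Assumed old order: start from the processor info, then let the
# environment info overwrite any shared keys.
old_info = processor_info.copy()
old_info.update(env_info)

# Assumed new order: start from the environment info, then let the
# processor info overwrite any shared keys.
new_info = env_info.copy()
new_info.update(processor_info)

# dict.update keeps the value from its argument on key collisions,
# so the two orders disagree only on the shared key.
shared_key_old = old_info["is_success"]  # value from env_info
shared_key_new = new_info["is_success"]  # value from processor_info
```

Keys that appear in only one of the two dictionaries end up in the result either way; only the colliding keys are affected by the order.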
+2 -1
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any
from lerobot.cameras import make_cameras_from_configs
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import _reachy2_sdk_available
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -81,6 +81,7 @@ class Reachy2Robot(Robot):
name = "reachy2"
def __init__(self, config: Reachy2RobotConfig):
require_package("reachy2_sdk", extra="reachy2")
super().__init__(config)
self.config = config
+2 -1
@@ -27,7 +27,7 @@ import numpy as np
from lerobot.cameras import make_cameras_from_configs
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import _unitree_sdk_available
from lerobot.utils.import_utils import _unitree_sdk_available, require_package
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
@@ -111,6 +111,7 @@ class UnitreeG1(Robot):
name = "unitree_g1"
def __init__(self, config: UnitreeG1Config):
require_package("unitree-sdk2py", extra="unitree_g1", import_name="unitree_sdk2py")
super().__init__(config)
logger.info("Initialize UnitreeG1...")
@@ -15,9 +15,22 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _hidapi_available, _pygame_available, require_package
from ..utils import TeleopEvents
if TYPE_CHECKING or _pygame_available:
import pygame
else:
pygame = None # type: ignore[assignment]
if TYPE_CHECKING or _hidapi_available:
import hid
else:
hid = None # type: ignore[assignment]
class InputController:
"""Base class for input controllers that generate motion deltas."""
@@ -199,6 +212,7 @@ class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
require_package("pygame", extra="gamepad")
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
@@ -206,8 +220,6 @@ class GamepadController(InputController):
def start(self):
"""Initialize pygame and the gamepad."""
import pygame
pygame.init()
pygame.joystick.init()
@@ -230,8 +242,6 @@ class GamepadController(InputController):
def stop(self):
"""Clean up pygame resources."""
import pygame
if pygame.joystick.get_init():
if self.joystick:
self.joystick.quit()
@@ -240,8 +250,6 @@ class GamepadController(InputController):
def update(self):
"""Process pygame events to get fresh gamepad readings."""
import pygame
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == 3:
@@ -280,8 +288,6 @@ class GamepadController(InputController):
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
import pygame
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
@@ -326,6 +332,7 @@ class GamepadControllerHID(InputController):
z_scale: Scaling factor for Z-axis movement
deadzone: Joystick deadzone to prevent drift
"""
require_package("hidapi", extra="gamepad", import_name="hid")
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.device = None
@@ -342,8 +349,6 @@ class GamepadControllerHID(InputController):
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
import hid
devices = hid.enumerate()
for device in devices:
device_name = device["product_string"]
@@ -357,8 +362,6 @@ class GamepadControllerHID(InputController):
def start(self):
"""Connect to the gamepad using HIDAPI."""
import hid
self.device_info = self.find_device()
if not self.device_info:
self.running = False
@@ -45,7 +45,7 @@ class HomunculusArm(Teleoperator):
name = "homunculus_arm"
def __init__(self, config: HomunculusArmConfig):
require_package("pyserial", extra="hardware", import_name="serial")
require_package("pyserial", extra="pyserial-dep", import_name="serial")
super().__init__(config)
self.config = config
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
@@ -71,7 +71,7 @@ class HomunculusGlove(Teleoperator):
name = "homunculus_glove"
def __init__(self, config: HomunculusGloveConfig):
require_package("pyserial", extra="hardware", import_name="serial")
require_package("pyserial", extra="pyserial-dep", import_name="serial")
super().__init__(config)
self.config = config
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
@@ -23,7 +23,7 @@ from typing import Any
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.import_utils import _pynput_available, require_package
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -56,6 +56,7 @@ class KeyboardTeleop(Teleoperator):
name = "keyboard"
def __init__(self, config: KeyboardTeleopConfig):
require_package("pynput", extra="pynput-dep")
super().__init__(config)
self.config = config
self.robot_type = config.type
@@ -21,14 +21,24 @@
import logging
import threading
import time
from typing import TYPE_CHECKING
import hebi
import numpy as np
from teleop import Teleop
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _hebi_available, _teleop_available, require_package
from lerobot.utils.rotation import Rotation
if TYPE_CHECKING or _hebi_available:
import hebi
else:
hebi = None
if TYPE_CHECKING or _teleop_available:
from teleop import Teleop
else:
Teleop = None
from ..teleoperator import Teleoperator
from .config_phone import PhoneConfig, PhoneOS
@@ -74,6 +84,8 @@ class IOSPhone(BasePhone, Teleoperator):
name = "ios_phone"
def __init__(self, config: PhoneConfig):
require_package("hebi-py", extra="phone", import_name="hebi")
require_package("teleop", extra="phone")
super().__init__(config)
self.config = config
self._group = None
@@ -213,6 +225,8 @@ class AndroidPhone(BasePhone, Teleoperator):
name = "android_phone"
def __init__(self, config: PhoneConfig):
require_package("hebi-py", extra="phone", import_name="hebi")
require_package("teleop", extra="phone")
super().__init__(config)
self.config = config
self._teleop = None
@@ -19,7 +19,7 @@ import logging
import time
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _reachy2_sdk_available
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
if TYPE_CHECKING or _reachy2_sdk_available:
from reachy2_sdk import ReachySDK
@@ -84,6 +84,7 @@ class Reachy2Teleoperator(Teleoperator):
name = "reachy2_specific"
def __init__(self, config: Reachy2TeleoperatorConfig):
require_package("reachy2_sdk", extra="reachy2")
super().__init__(config)
self.config = config
@@ -34,7 +34,7 @@ from typing import TYPE_CHECKING
import numpy as np
from lerobot.utils.import_utils import _serial_available
from lerobot.utils.import_utils import _serial_available, require_package
if TYPE_CHECKING or _serial_available:
import serial
@@ -156,6 +156,7 @@ def run_exo_calibration(
"""
Run interactive calibration for an exoskeleton arm.
"""
require_package("pyserial", extra="unitree_g1", import_name="serial")
try:
import cv2
import matplotlib.pyplot as plt
@@ -76,7 +76,7 @@ class ExoskeletonArm:
calibration: ExoskeletonCalibration | None = None
def __post_init__(self):
require_package("pyserial", extra="hardware", import_name="serial")
require_package("pyserial", extra="unitree_g1", import_name="serial")
if self.calibration_fpath.is_file():
self._load_calibration()
+6
@@ -115,6 +115,12 @@ _feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_pyrealsense2_available = is_package_available("pyrealsense2")
_zmq_available = is_package_available("pyzmq", import_name="zmq")
_hebi_available = is_package_available("hebi-py", import_name="hebi")
_teleop_available = is_package_available("teleop")
_placo_available = is_package_available("placo")
_hidapi_available = is_package_available("hidapi", import_name="hid")
# Data / serialization
_pandas_available = is_package_available("pandas")
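Note on the pattern used throughout this commit: each optional dependency gets a module-level availability flag, a guarded import that falls back to `None`, and a `require_package(...)` call at the top of `__init__` so that a missing package fails early with a clear message. The sketch below is one possible way to write such helpers and is not the code in `lerobot.utils.import_utils`. The bodies, the dash-to-underscore fallback for the import name, and the wording of the error message are all assumptions; only the call shapes (`is_package_available(name, import_name=...)` and `require_package(name, extra=..., import_name=...)`) are taken from the diff.

```python
from __future__ import annotations

import importlib.util


def is_package_available(pkg_name: str, import_name: str | None = None) -> bool:
    """Return True if the module can be located, without importing it."""
    # Assumption: fall back to the distribution name with dashes replaced.
    module = import_name or pkg_name.replace("-", "_")
    try:
        return importlib.util.find_spec(module) is not None
    except (ImportError, ValueError):
        return False


def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
    """Raise ImportError with an install hint if the package is missing."""
    if not is_package_available(pkg_name, import_name=import_name):
        # The message format is hypothetical.
        raise ImportError(
            f"'{pkg_name}' is required for this feature. "
            f"Try installing the '{extra}' extra."
        )


# Guarded import: the module still loads when the dependency is absent.
# 'json' stands in for an optional third-party package here.
_json_available = is_package_available("json")
if _json_available:
    import json
else:
    json = None


class NeedsJson:
    """Hypothetical class that depends on the optional package."""

    def __init__(self):
        # Fail at construction time rather than at first use.
        require_package("json", extra="demo")
        self.payload = json.dumps({"ok": True})
```

Checking in `__init__` rather than at import time means that importing the module never fails, which matters when a registry imports every policy, robot, or teleoperator module regardless of which extras are installed.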
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train()
# Use preprocessor to handle tokenization
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.eval()
policy.reset() # Reset queues before inference
@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train()
# Use preprocessor to handle tokenization
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
config.validate_features()
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.train()
# Use preprocessor to handle tokenization
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
policy.eval()
# Get device before saving
device = next(policy.parameters()).device
policy.save_pretrained(root)
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
# Explicitly move loaded_policy to the same device
loaded_policy.to(device)
loaded_policy.to(config.device)
loaded_policy.eval()
batch = create_train_batch(
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
with seeded_context(12):
# Process batch through preprocessor
processed_batch = preprocessor(batch)
# Move batch to the same device as the policy
for key in processed_batch:
if isinstance(processed_batch[key], torch.Tensor):
processed_batch[key] = processed_batch[key].to(device)
# Collect policy values before saving
loss, _ = policy.forward(processed_batch)
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
)
policy = MultiTaskDiTPolicy(config=config)
policy.to(config.device)
param_groups = policy.get_optim_params()
# Should have 2 parameter groups: non-vision and vision encoder
@@ -18,6 +18,11 @@ from unittest.mock import MagicMock, patch
import pytest
from lerobot.utils.import_utils import is_package_available
if not is_package_available("reachy2_sdk"):
pytest.skip("reachy2_sdk not available", allow_module_level=True)
from lerobot.teleoperators.reachy2_teleoperator import (
REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS,
Generated
+970 -241
File diff suppressed because it is too large