mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 08:37:10 +00:00
Merge branch 'main' into feat/robotwin-benchmark
This commit is contained in:
@@ -2,11 +2,6 @@
|
||||
|
||||
Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). See [CONTRIBUTING.md](../CONTRIBUTING.md) for PR conventions.
|
||||
|
||||
## Type / Scope
|
||||
|
||||
- **Type**: (Bug | Feature | Docs | Performance | Test | CI | Chore)
|
||||
- **Scope**: (optional — name of module or package affected)
|
||||
|
||||
## Summary / Motivation
|
||||
|
||||
- One-paragraph description of what changes and why.
|
||||
@@ -19,28 +14,14 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S
|
||||
|
||||
## What changed
|
||||
|
||||
- Short, concrete bullets of the modifications (files/behaviour).
|
||||
- Short, concrete bullets explaining the functional changes (how the behavior or output differs now).
|
||||
- Short note if this introduces breaking changes and migration steps.
|
||||
|
||||
## How was this tested (or how to run locally)
|
||||
|
||||
- Tests added: list new tests or test files.
|
||||
- Tests added: list new tests or test files. `pytest -q tests/ -k <keyword>`
|
||||
- Manual checks / dataset runs performed.
|
||||
- Instructions for the reviewer
|
||||
|
||||
Example:
|
||||
|
||||
- Ran the relevant tests:
|
||||
|
||||
```bash
|
||||
pytest -q tests/ -k <keyword>
|
||||
```
|
||||
|
||||
- Reproduce with a quick example or CLI (if applicable):
|
||||
|
||||
```bash
|
||||
lerobot-train --some.option=true
|
||||
```
|
||||
- Instructions for the reviewer for reproducing with a quick example or CLI (if applicable)
|
||||
|
||||
## Checklist (required before merge)
|
||||
|
||||
@@ -48,6 +29,7 @@ Example:
|
||||
- [ ] All tests pass locally (`pytest`)
|
||||
- [ ] Documentation updated
|
||||
- [ ] CI is green
|
||||
- [ ] Community Review: I have reviewed another contributor's open PR and linked it here: # (insert PR number/link)
|
||||
|
||||
## Reviewer notes
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
|
||||
+4
-1
@@ -78,6 +78,9 @@ Use the templates for required fields and examples.
|
||||
- **Issues:** Follow the [ticket template](https://github.com/huggingface/lerobot/blob/main/.github/ISSUE_TEMPLATE/bug-report.yml).
|
||||
- **Pull requests:** Rebase on `upstream/main`, use a descriptive branch (don't work on `main`), run `pre-commit` and tests locally, and follow the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md).
|
||||
|
||||
One member of the LeRobot team will then review your contribution.
|
||||
> [!IMPORTANT]
|
||||
> Community Review Policy: To help scale our efforts and foster a collaborative environment, we ask contributors to review at least one other person's open PR before their own receives attention. This shared responsibility multiplies our review capacity and helps everyone's code get merged faster!
|
||||
|
||||
Once you have submitted your PR and completed a peer review, a member of the LeRobot team will review your contribution.
|
||||
|
||||
Thank you for contributing to LeRobot!
|
||||
|
||||
@@ -32,6 +32,12 @@ Once you’ve gathered enough trajectories, you’ll train a neural network to i
|
||||
|
||||
If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support.
|
||||
|
||||
<Tip>
|
||||
|
||||
Want to quickly get the right commands for your setup? The [quickstart notebook](https://github.com/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) [](https://colab.research.google.com/github/huggingface/lerobot/blob/main/examples/notebooks/quickstart.ipynb) lets you configure your robot once and generates all the commands below ready to paste.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Set up and Calibrate
|
||||
|
||||
If you haven't yet set up and calibrated your robot and teleop device, please do so by following the robot-specific tutorial.
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 🤗 LeRobot Quickstart\n",
|
||||
"\n",
|
||||
"Calibration → teleoperation → data collection → training → evaluation.\n",
|
||||
"\n",
|
||||
"Install the required dependencies: `pip install -e .[notebook,dataset,training,viz,hardware]`.\n",
|
||||
"\n",
|
||||
"**How to use:**\n",
|
||||
"1. Edit the **Configuration** cell with your settings.\n",
|
||||
"2. Run all cells (`Run All`).\n",
|
||||
"3. Each section prints a ready-to-paste terminal command - copy it and run it.\n",
|
||||
"\n",
|
||||
"Each setup is different, please refer to the [LeRobot documentation](https://huggingface.co/docs/lerobot/il_robots) for more details on each step and available options. <br>\n",
|
||||
"Feel free to make this notebook your own and adapt it to your needs!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## Utils"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _cameras_arg(cameras: dict) -> str:\n",
|
||||
" if not cameras:\n",
|
||||
" return \"\"\n",
|
||||
" entries = [f\"{n}: {{{', '.join(f'{k}: {v}' for k, v in cfg.items())}}}\" for n, cfg in cameras.items()]\n",
|
||||
" return \"{ \" + \", \".join(entries) + \" }\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def print_cmd(*parts: str) -> None:\n",
|
||||
" \"\"\"Print a shell command with line continuations, skipping empty parts.\"\"\"\n",
|
||||
" non_empty = [p for p in parts if p]\n",
|
||||
" print(\" \\\\\\n \".join(non_empty))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## Configuration\n",
|
||||
"\n",
|
||||
"Edit this cell, then **Run All** to generate all commands below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Robot (follower) - run `lerobot-find-port` to discover the port\n",
|
||||
"ROBOT_TYPE = \"so101_follower\"\n",
|
||||
"ROBOT_PORT = \"/dev/ttyACM0\"\n",
|
||||
"ROBOT_ID = \"my_follower_arm\"\n",
|
||||
"\n",
|
||||
"# Teleop (leader) - run `lerobot-find-port` to discover the port\n",
|
||||
"TELEOP_TYPE = \"so101_leader\"\n",
|
||||
"TELEOP_PORT = \"/dev/ttyACM1\"\n",
|
||||
"TELEOP_ID = \"my_leader_arm\"\n",
|
||||
"\n",
|
||||
"# Cameras - set to {} to disable\n",
|
||||
"# Run `lerobot-find-cameras opencv` to list available cameras and their indices\n",
|
||||
"CAMERAS = {\n",
|
||||
" \"top\": {\"type\": \"opencv\", \"index_or_path\": 2, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
|
||||
" \"wrist\": {\"type\": \"opencv\", \"index_or_path\": 4, \"width\": 640, \"height\": 480, \"fps\": 30},\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Dataset\n",
|
||||
"HF_USER = \"your_hf_username\" # `huggingface-cli whoami` to find your username\n",
|
||||
"DATASET_NAME = \"my_so101_dataset\"\n",
|
||||
"TASK_DESCRIPTION = \"pick and place the block\"\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
"\n",
|
||||
"# Training\n",
|
||||
"POLICY_TYPE = \"act\" # act, diffusion, smolvla, ...\n",
|
||||
"POLICY_DEVICE = \"cuda\" # cuda / cpu / mps\n",
|
||||
"TRAIN_STEPS = 10_000\n",
|
||||
"SAVE_FREQ = 2_000\n",
|
||||
"OUTPUT_DIR = f\"outputs/train/{DATASET_NAME}\"\n",
|
||||
"\n",
|
||||
"# Inference - Hub repo ID or local checkpoint path\n",
|
||||
"# e.g. set to f\"{OUTPUT_DIR}/checkpoints/last\" to use a local checkpoint\n",
|
||||
"POLICY_PATH = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
|
||||
"LAST_CHECKPOINT_PATH = f\"{OUTPUT_DIR}/checkpoints/last\"\n",
|
||||
"\n",
|
||||
"# Derived\n",
|
||||
"DATASET_REPO_ID = f\"{HF_USER}/{DATASET_NAME}\"\n",
|
||||
"DATASET_ROOT = f\"data/{DATASET_NAME}\"\n",
|
||||
"POLICY_REPO_ID = f\"{HF_USER}/{DATASET_NAME}_{POLICY_TYPE}\"\n",
|
||||
"EVAL_REPO_ID = f\"{HF_USER}/eval_{DATASET_NAME}\"\n",
|
||||
"CAMERAS_ARG = _cameras_arg(CAMERAS)\n",
|
||||
"CAMERAS_FLAG = f'--robot.cameras=\"{CAMERAS_ARG}\"' if CAMERAS_ARG else \"\"\n",
|
||||
"\n",
|
||||
"print(f\"Robot : {ROBOT_TYPE} @ {ROBOT_PORT}\")\n",
|
||||
"print(f\"Teleop : {TELEOP_TYPE} @ {TELEOP_PORT}\")\n",
|
||||
"print(f\"Cameras: {list(CAMERAS) or 'none'}\")\n",
|
||||
"print(f\"Dataset: {DATASET_REPO_ID} ({NUM_EPISODES} episodes) saved to {DATASET_ROOT}\")\n",
|
||||
"print(f\"Policy : {POLICY_TYPE} -> {POLICY_REPO_ID}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 1. Calibration\n",
|
||||
"\n",
|
||||
"Run once per arm before first use."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Follower\n",
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-calibrate\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" f\"--robot.id={ROBOT_ID}\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Leader\n",
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-calibrate\",\n",
|
||||
" f\"--teleop.type={TELEOP_TYPE}\",\n",
|
||||
" f\"--teleop.port={TELEOP_PORT}\",\n",
|
||||
" f\"--teleop.id={TELEOP_ID}\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 2. Teleoperation\n",
|
||||
"\n",
|
||||
"See the [teleoperation docs](https://huggingface.co/docs/lerobot/il_robots#teleoperate) and the [cameras guide](https://huggingface.co/docs/lerobot/cameras) for more options."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-teleoperate\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" f\"--robot.id={ROBOT_ID}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f\"--teleop.type={TELEOP_TYPE}\",\n",
|
||||
" f\"--teleop.port={TELEOP_PORT}\",\n",
|
||||
" f\"--teleop.id={TELEOP_ID}\",\n",
|
||||
" \"--display_data=true\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 3. Record Dataset\n",
|
||||
"\n",
|
||||
"See the [recording docs](https://huggingface.co/docs/lerobot/il_robots#record-a-dataset) for tips on gathering good data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-record\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" f\"--robot.id={ROBOT_ID}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f\"--teleop.type={TELEOP_TYPE}\",\n",
|
||||
" f\"--teleop.port={TELEOP_PORT}\",\n",
|
||||
" f\"--teleop.id={TELEOP_ID}\",\n",
|
||||
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
|
||||
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
|
||||
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--dataset.streaming_encoding=true\",\n",
|
||||
" \"--display_data=true\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Resume a previously interrupted recording session\n",
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-record\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" f\"--robot.id={ROBOT_ID}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f\"--teleop.type={TELEOP_TYPE}\",\n",
|
||||
" f\"--teleop.port={TELEOP_PORT}\",\n",
|
||||
" f\"--teleop.id={TELEOP_ID}\",\n",
|
||||
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
|
||||
" f\"--dataset.root={DATASET_ROOT}\",\n",
|
||||
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
|
||||
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--dataset.streaming_encoding=true\",\n",
|
||||
" \"--display_data=true\",\n",
|
||||
" \"--resume=true\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 4. Train Policy\n",
|
||||
"\n",
|
||||
"See the [training docs](https://huggingface.co/docs/lerobot/il_robots#train-a-policy) for configuration options and tips."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-train\",\n",
|
||||
" f\"--dataset.repo_id={DATASET_REPO_ID}\",\n",
|
||||
" f\"--policy.type={POLICY_TYPE}\",\n",
|
||||
" f\"--policy.device={POLICY_DEVICE}\",\n",
|
||||
" f\"--policy.repo_id={POLICY_REPO_ID}\",\n",
|
||||
" f\"--output_dir={OUTPUT_DIR}\",\n",
|
||||
" f\"--steps={TRAIN_STEPS}\",\n",
|
||||
" f\"--save_freq={SAVE_FREQ}\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Resume a previously interrupted training session\n",
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-train\",\n",
|
||||
" f\"--config_path={LAST_CHECKPOINT_PATH}/pretrained_model/train_config.json\",\n",
|
||||
" \"--resume=true\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"## 5. Inference\n",
|
||||
"\n",
|
||||
"Uses `POLICY_PATH` from the Configuration cell (defaults to the Hub repo ID). You can also put there the `LAST_CHECKPOINT_PATH`.\n",
|
||||
"\n",
|
||||
"See the [inference docs](https://huggingface.co/docs/lerobot/il_robots#run-inference-and-evaluate-your-policy) for details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print_cmd(\n",
|
||||
" \"lerobot-record\",\n",
|
||||
" f\"--policy.path={POLICY_PATH}\",\n",
|
||||
" f\"--robot.type={ROBOT_TYPE}\",\n",
|
||||
" f\"--robot.port={ROBOT_PORT}\",\n",
|
||||
" f\"--robot.id={ROBOT_ID}\",\n",
|
||||
" CAMERAS_FLAG,\n",
|
||||
" f\"--teleop.type={TELEOP_TYPE}\",\n",
|
||||
" f\"--teleop.port={TELEOP_PORT}\",\n",
|
||||
" f\"--teleop.id={TELEOP_ID}\",\n",
|
||||
" f\"--dataset.repo_id={EVAL_REPO_ID}\",\n",
|
||||
" f\"--dataset.num_episodes={NUM_EPISODES}\",\n",
|
||||
" f'--dataset.single_task=\"{TASK_DESCRIPTION}\"',\n",
|
||||
" \"--dataset.streaming_encoding=true\",\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "lerobot (3.12.3)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
+14
-8
@@ -108,9 +108,9 @@ training = [
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
]
|
||||
hardware = [
|
||||
"pynput>=1.7.8,<1.9.0",
|
||||
"pyserial>=3.5,<4.0",
|
||||
"deepdiff>=7.0.1,<9.0.0",
|
||||
"lerobot[pynput-dep]",
|
||||
"lerobot[pyserial-dep]",
|
||||
"lerobot[deepdiff-dep]",
|
||||
]
|
||||
viz = [
|
||||
"rerun-sdk>=0.24.0,<0.27.0",
|
||||
@@ -136,10 +136,14 @@ scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||
diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
|
||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
|
||||
pyserial-dep = ["pyserial>=3.5,<4.0"]
|
||||
deepdiff-dep = ["deepdiff>=7.0.1,<9.0.0"]
|
||||
pynput-dep = ["pynput>=1.7.8,<1.9.0"]
|
||||
pyzmq-dep = ["pyzmq>=26.2.1,<28.0.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0", "lerobot[pyserial-dep]", "lerobot[deepdiff-dep]"]
|
||||
damiao = ["lerobot[can-dep]"]
|
||||
robstride = ["lerobot[can-dep]"]
|
||||
|
||||
@@ -147,10 +151,11 @@ robstride = ["lerobot[can-dep]"]
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
lekiwi = ["lerobot[feetech]", "lerobot[pyzmq-dep]"]
|
||||
unitree_g1 = [
|
||||
# "unitree-sdk2==1.0.1",
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"lerobot[pyzmq-dep]",
|
||||
"lerobot[pyserial-dep]",
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"onnx>=1.16.0,<2.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
@@ -196,7 +201,8 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1"]
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
@@ -76,6 +76,7 @@ class Reachy2Camera(Camera):
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
require_package("reachy2_sdk", extra="reachy2")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -19,16 +19,18 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
try:
|
||||
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
from lerobot.utils.import_utils import _pyrealsense2_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _pyrealsense2_available:
|
||||
import pyrealsense2 as rs
|
||||
else:
|
||||
rs = None
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
@@ -112,7 +114,7 @@ class RealSenseCamera(Camera):
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
|
||||
require_package("pyrealsense2", extra="intelrealsense")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -28,12 +28,19 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.import_utils import _zmq_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _zmq_available:
|
||||
import zmq
|
||||
else:
|
||||
zmq = None
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
@@ -74,8 +81,8 @@ class ZMQCamera(Camera):
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZMQCameraConfig):
|
||||
require_package("pyzmq", extra="pyzmq-dep", import_name="zmq")
|
||||
super().__init__(config)
|
||||
import zmq
|
||||
|
||||
self.config = config
|
||||
self.server_address = config.server_address
|
||||
@@ -117,8 +124,6 @@ class ZMQCamera(Camera):
|
||||
logger.info(f"Connecting to {self}...")
|
||||
|
||||
try:
|
||||
import zmq
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.SUB)
|
||||
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
@@ -180,11 +185,8 @@ class ZMQCamera(Camera):
|
||||
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# zmq is lazy-imported in connect(), so check by name to avoid a top-level import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
except zmq.Again as e:
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
|
||||
# Decode JSON message
|
||||
data = json.loads(message)
|
||||
|
||||
@@ -28,6 +28,12 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
|
||||
from lerobot.utils.import_utils import _deepdiff_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _deepdiff_available:
|
||||
from deepdiff import DeepDiff
|
||||
else:
|
||||
DeepDiff = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
@@ -217,10 +223,7 @@ def sanity_check_dataset_robot_compatibility(
|
||||
Raises:
|
||||
ValueError: If any of the checked metadata fields do not match.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("deepdiff", extra="hardware")
|
||||
from deepdiff import DeepDiff
|
||||
require_package("deepdiff", extra="deepdiff-dep")
|
||||
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
|
||||
|
||||
@@ -30,13 +30,13 @@ def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
except BaseException:
|
||||
dataset = kwargs.get("dataset")
|
||||
writer = getattr(dataset, "writer", None) if dataset else None
|
||||
if writer is not None and writer.image_writer is not None:
|
||||
logger.warning("Waiting for image writer to terminate...")
|
||||
writer.image_writer.stop()
|
||||
raise e
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -12,8 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _placo_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _placo_available:
|
||||
import placo # type: ignore[import-not-found]
|
||||
else:
|
||||
placo = None
|
||||
|
||||
|
||||
class RobotKinematics:
|
||||
"""Robot kinematics using placo library for forward and inverse kinematics."""
|
||||
@@ -32,13 +43,7 @@ class RobotKinematics:
|
||||
target_frame_name (str): Name of the end-effector frame in the URDF
|
||||
joint_names (list[str] | None): List of joint names to use for the kinematics solver
|
||||
"""
|
||||
try:
|
||||
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"placo is required for RobotKinematics. "
|
||||
"Please install the optional dependencies of `kinematics` in the package."
|
||||
) from e
|
||||
require_package("placo", extra="placo-dep")
|
||||
|
||||
self.robot = placo.RobotWrapper(urdf_path)
|
||||
self.solver = placo.KinematicsSolver(self.robot)
|
||||
|
||||
@@ -24,7 +24,7 @@ from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
from lerobot.utils.import_utils import _can_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
@@ -111,6 +111,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
require_package("python-can", extra="damiao", import_name="can")
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
|
||||
@@ -356,8 +356,8 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
require_package("pyserial", extra="hardware", import_name="serial")
|
||||
require_package("deepdiff", extra="hardware")
|
||||
require_package("pyserial", extra="pyserial-dep", import_name="serial")
|
||||
require_package("deepdiff", extra="deepdiff-dep")
|
||||
super().__init__(port, motors, calibration)
|
||||
|
||||
self.port_handler: PortHandler
|
||||
|
||||
@@ -23,12 +23,12 @@ from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
from lerobot.utils.import_utils import _can_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
else:
|
||||
can = SimpleNamespace(Message=object, interface=None)
|
||||
can = SimpleNamespace(Message=object, interface=None, BusABC=object)
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
@@ -106,6 +106,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
||||
bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps)
|
||||
data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False
|
||||
"""
|
||||
require_package("python-can", extra="robstride", import_name="can")
|
||||
super().__init__(port, motors, calibration)
|
||||
self.port = port
|
||||
self.can_interface = can_interface
|
||||
|
||||
@@ -18,14 +18,21 @@ import logging
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.utils.constants import SCHEDULER_STATE
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.optimization import get_scheduler
|
||||
else:
|
||||
get_scheduler = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@@ -47,10 +54,7 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int | None = None
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("diffusers", extra="diffusion")
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
@@ -23,6 +23,7 @@ TODO(alexander-soare):
|
||||
import math
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -32,6 +33,14 @@ import torchvision
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
else:
|
||||
DDIMScheduler = None
|
||||
DDPMScheduler = None
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import (
|
||||
@@ -64,6 +73,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
require_package("diffusers", extra="diffusion")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
@@ -155,11 +165,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict):
|
||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||
to the scheduler.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("diffusers", extra="diffusion")
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
if name == "DDPM":
|
||||
return DDPMScheduler(**kwargs)
|
||||
|
||||
@@ -204,7 +204,9 @@ class FlowmatchingActionHead(nn.Module):
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
@@ -249,7 +251,9 @@ class FlowmatchingActionHead(nn.Module):
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
if self._beta_dist is None:
|
||||
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
|
||||
@@ -222,6 +222,13 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
if isinstance(image_inputs["pixel_values"], list):
|
||||
_pv = image_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
image_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
@@ -233,6 +240,13 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
if isinstance(video_inputs["pixel_values"], list):
|
||||
_pv = video_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
video_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
@@ -288,8 +302,18 @@ class Eagle25VLProcessor(ProcessorMixin):
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
|
||||
|
||||
def _to_tensor(v):
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v
|
||||
if isinstance(v, list):
|
||||
if v and isinstance(v[0], list):
|
||||
v = [t for sub in v for t in sub]
|
||||
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||
return torch.as_tensor(v)
|
||||
|
||||
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
|
||||
@@ -221,6 +221,7 @@ class GR00TN15(PreTrainedModel):
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
self.post_init()
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
|
||||
@@ -43,6 +43,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_groot import GrootConfig
|
||||
@@ -59,6 +60,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
def __init__(self, config: GrootConfig, **kwargs):
|
||||
"""Initialize Groot policy wrapper."""
|
||||
require_package("transformers", extra="groot")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
@@ -36,7 +36,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package
|
||||
|
||||
from .configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
|
||||
@@ -46,6 +46,13 @@ if TYPE_CHECKING or _transformers_available:
|
||||
else:
|
||||
CLIPTextModel = None
|
||||
CLIPVisionModel = None
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
else:
|
||||
DDIMScheduler = None
|
||||
DDPMScheduler = None
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGES,
|
||||
@@ -65,6 +72,8 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
|
||||
name = "multi_task_dit"
|
||||
|
||||
def __init__(self, config: MultiTaskDiTConfig, **kwargs):
|
||||
require_package("transformers", extra="multi_task_dit")
|
||||
require_package("diffusers", extra="multi_task_dit")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
@@ -643,12 +652,6 @@ class DiffusionObjective(nn.Module):
|
||||
"prediction_type": config.prediction_type,
|
||||
}
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("diffusers", extra="multi_task_dit")
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
if config.noise_scheduler_type == "DDPM":
|
||||
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
|
||||
elif config.noise_scheduler_type == "DDIM":
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
@@ -947,6 +947,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
Args:
|
||||
config: Policy configuration class instance.
|
||||
"""
|
||||
require_package("transformers", extra="pi")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
@@ -918,6 +918,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
Args:
|
||||
config: Policy configuration class instance.
|
||||
"""
|
||||
require_package("transformers", extra="pi")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _scipy_available:
|
||||
@@ -35,7 +35,7 @@ else:
|
||||
idct = None
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from ..pi_gemma import (
|
||||
@@ -44,6 +44,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
AutoProcessor = None
|
||||
AutoTokenizer = None
|
||||
PiGemmaModel = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
@@ -826,14 +827,14 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
Args:
|
||||
config: Policy configuration class instance.
|
||||
"""
|
||||
require_package("transformers", extra="pi")
|
||||
require_package("scipy", extra="pi")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
# Load tokenizers first
|
||||
try:
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
# Load FAST tokenizer
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
config.action_tokenizer_name, trust_remote_code=True
|
||||
|
||||
@@ -62,6 +62,7 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.device_utils import get_safe_dtype
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..rtc.modeling_rtc import RTCProcessor
|
||||
@@ -239,6 +240,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
the configuration class is used.
|
||||
"""
|
||||
|
||||
require_package("transformers", extra="smolvla")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
+54
-58
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import threading
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import suppress
|
||||
from typing import TypedDict
|
||||
@@ -115,6 +116,7 @@ class ReplayBuffer:
|
||||
self.size = 0
|
||||
self.initialized = False
|
||||
self.optimize_memory = optimize_memory
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Track episode boundaries for memory optimization
|
||||
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
|
||||
@@ -198,68 +200,75 @@ class ReplayBuffer:
|
||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Initialize storage if this is the first transition
|
||||
if not self.initialized:
|
||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||
with self._lock:
|
||||
# Initialize storage if this is the first transition
|
||||
if not self.initialized:
|
||||
self._initialize_storage(state=state, action=action, complementary_info=complementary_info)
|
||||
|
||||
# Store the transition in pre-allocated tensors
|
||||
for key in self.states:
|
||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||
# Store the transition in pre-allocated tensors
|
||||
for key in self.states:
|
||||
self.states[key][self.position].copy_(state[key].squeeze(dim=0))
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Only store next_states if not optimizing memory
|
||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||
if not self.optimize_memory:
|
||||
# Only store next_states if not optimizing memory
|
||||
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
|
||||
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
# Store the complementary_info
|
||||
for key in self.complementary_info_keys:
|
||||
if key in complementary_info:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int | float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
for key in self.complementary_info_keys:
|
||||
if key in complementary_info:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int | float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
self.size = min(self.size + 1, self.capacity)
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
if not self.initialized:
|
||||
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
|
||||
|
||||
batch_size = min(batch_size, self.size)
|
||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||
with self._lock:
|
||||
batch_size = min(batch_size, self.size)
|
||||
high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size
|
||||
|
||||
# Random indices for sampling - create on the same device as storage
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
|
||||
|
||||
# Identify image keys that need augmentation
|
||||
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
||||
image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
|
||||
|
||||
# Create batched state and next_state
|
||||
batch_state = {}
|
||||
batch_next_state = {}
|
||||
batch_state = {}
|
||||
batch_next_state = {}
|
||||
|
||||
# First pass: load all state tensors to target device
|
||||
for key in self.states:
|
||||
batch_state[key] = self.states[key][idx].to(self.device)
|
||||
for key in self.states:
|
||||
batch_state[key] = self.states[key][idx].to(self.device)
|
||||
|
||||
if not self.optimize_memory:
|
||||
# Standard approach - load next_states directly
|
||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||
else:
|
||||
# Memory-optimized approach - get next_state from the next index
|
||||
next_idx = (idx + 1) % self.capacity
|
||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||
if not self.optimize_memory:
|
||||
batch_next_state[key] = self.next_states[key][idx].to(self.device)
|
||||
else:
|
||||
next_idx = (idx + 1) % self.capacity
|
||||
batch_next_state[key] = self.states[key][next_idx].to(self.device)
|
||||
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
if self.has_complementary_info:
|
||||
batch_complementary_info = {}
|
||||
for key in self.complementary_info_keys:
|
||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||
|
||||
# Apply image augmentation in a batched way if needed
|
||||
if self.use_drq and image_keys:
|
||||
# Concatenate all images from state and next_state
|
||||
all_images = []
|
||||
@@ -280,19 +289,6 @@ class ReplayBuffer:
|
||||
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
|
||||
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
|
||||
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
if self.has_complementary_info:
|
||||
batch_complementary_info = {}
|
||||
for key in self.complementary_info_keys:
|
||||
batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device)
|
||||
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
|
||||
@@ -551,8 +551,8 @@ def step_env_and_process_transition(
|
||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||
new_info = processed_action_transition[TransitionKey.INFO].copy()
|
||||
new_info.update(info)
|
||||
new_info = info.copy()
|
||||
new_info.update(processed_action_transition[TransitionKey.INFO])
|
||||
|
||||
new_transition = create_transition(
|
||||
observation=obs,
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.cameras import make_cameras_from_configs
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -81,6 +81,7 @@ class Reachy2Robot(Robot):
|
||||
name = "reachy2"
|
||||
|
||||
def __init__(self, config: Reachy2RobotConfig):
|
||||
require_package("reachy2_sdk", extra="reachy2")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -27,7 +27,7 @@ import numpy as np
|
||||
|
||||
from lerobot.cameras import make_cameras_from_configs
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import _unitree_sdk_available
|
||||
from lerobot.utils.import_utils import _unitree_sdk_available, require_package
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
@@ -111,6 +111,7 @@ class UnitreeG1(Robot):
|
||||
name = "unitree_g1"
|
||||
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
require_package("unitree-sdk2py", extra="unitree_g1", import_name="unitree_sdk2py")
|
||||
super().__init__(config)
|
||||
|
||||
logger.info("Initialize UnitreeG1...")
|
||||
|
||||
@@ -15,9 +15,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.utils.import_utils import _hidapi_available, _pygame_available, require_package
|
||||
|
||||
from ..utils import TeleopEvents
|
||||
|
||||
if TYPE_CHECKING or _pygame_available:
|
||||
import pygame
|
||||
else:
|
||||
pygame = None # type: ignore[assignment]
|
||||
|
||||
if TYPE_CHECKING or _hidapi_available:
|
||||
import hid
|
||||
else:
|
||||
hid = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
@@ -199,6 +212,7 @@ class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
require_package("pygame", extra="gamepad")
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
@@ -206,8 +220,6 @@ class GamepadController(InputController):
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
@@ -230,8 +242,6 @@ class GamepadController(InputController):
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
if self.joystick:
|
||||
self.joystick.quit()
|
||||
@@ -240,8 +250,6 @@ class GamepadController(InputController):
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
@@ -280,8 +288,6 @@ class GamepadController(InputController):
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
@@ -326,6 +332,7 @@ class GamepadControllerHID(InputController):
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
"""
|
||||
require_package("hidapi", extra="gamepad", import_name="hid")
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.device = None
|
||||
@@ -342,8 +349,6 @@ class GamepadControllerHID(InputController):
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
device_name = device["product_string"]
|
||||
@@ -357,8 +362,6 @@ class GamepadControllerHID(InputController):
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
if not self.device_info:
|
||||
self.running = False
|
||||
|
||||
@@ -45,7 +45,7 @@ class HomunculusArm(Teleoperator):
|
||||
name = "homunculus_arm"
|
||||
|
||||
def __init__(self, config: HomunculusArmConfig):
|
||||
require_package("pyserial", extra="hardware", import_name="serial")
|
||||
require_package("pyserial", extra="pyserial-dep", import_name="serial")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
|
||||
|
||||
@@ -71,7 +71,7 @@ class HomunculusGlove(Teleoperator):
|
||||
name = "homunculus_glove"
|
||||
|
||||
def __init__(self, config: HomunculusGloveConfig):
|
||||
require_package("pyserial", extra="hardware", import_name="serial")
|
||||
require_package("pyserial", extra="pyserial-dep", import_name="serial")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
from lerobot.utils.import_utils import _pynput_available, require_package
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
@@ -56,6 +56,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
name = "keyboard"
|
||||
|
||||
def __init__(self, config: KeyboardTeleopConfig):
|
||||
require_package("pynput", extra="pynput-dep")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.robot_type = config.type
|
||||
|
||||
@@ -21,14 +21,24 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _hebi_available, _teleop_available, require_package
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
if TYPE_CHECKING or _hebi_available:
|
||||
import hebi
|
||||
else:
|
||||
hebi = None
|
||||
|
||||
if TYPE_CHECKING or _teleop_available:
|
||||
from teleop import Teleop
|
||||
else:
|
||||
Teleop = None
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_phone import PhoneConfig, PhoneOS
|
||||
|
||||
@@ -74,6 +84,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
name = "ios_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
require_package("hebi-py", extra="phone", import_name="hebi")
|
||||
require_package("teleop", extra="phone")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._group = None
|
||||
@@ -213,6 +225,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
name = "android_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
require_package("hebi-py", extra="phone", import_name="hebi")
|
||||
require_package("teleop", extra="phone")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._teleop = None
|
||||
|
||||
@@ -19,7 +19,7 @@ import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
from reachy2_sdk import ReachySDK
|
||||
@@ -84,6 +84,7 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
name = "reachy2_specific"
|
||||
|
||||
def __init__(self, config: Reachy2TeleoperatorConfig):
|
||||
require_package("reachy2_sdk", extra="reachy2")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -34,7 +34,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.import_utils import _serial_available
|
||||
from lerobot.utils.import_utils import _serial_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _serial_available:
|
||||
import serial
|
||||
@@ -156,6 +156,7 @@ def run_exo_calibration(
|
||||
"""
|
||||
Run interactive calibration for an exoskeleton arm.
|
||||
"""
|
||||
require_package("pyserial", extra="unitree_g1", import_name="serial")
|
||||
try:
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -76,7 +76,7 @@ class ExoskeletonArm:
|
||||
calibration: ExoskeletonCalibration | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("pyserial", extra="hardware", import_name="serial")
|
||||
require_package("pyserial", extra="unitree_g1", import_name="serial")
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
|
||||
@@ -115,6 +115,12 @@ _feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="
|
||||
_reachy2_sdk_available = is_package_available("reachy2_sdk")
|
||||
_can_available = is_package_available("python-can", "can")
|
||||
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
|
||||
_pyrealsense2_available = is_package_available("pyrealsense2")
|
||||
_zmq_available = is_package_available("pyzmq", import_name="zmq")
|
||||
_hebi_available = is_package_available("hebi-py", import_name="hebi")
|
||||
_teleop_available = is_package_available("teleop")
|
||||
_placo_available = is_package_available("placo")
|
||||
_hidapi_available = is_package_available("hidapi", import_name="hid")
|
||||
|
||||
# Data / serialization
|
||||
_pandas_available = is_package_available("pandas")
|
||||
|
||||
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.to(config.device)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.to(config.device)
|
||||
policy.eval()
|
||||
policy.reset() # Reset queues before inference
|
||||
|
||||
@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.to(config.device)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
||||
config.validate_features()
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.to(config.device)
|
||||
policy.train()
|
||||
|
||||
# Use preprocessor to handle tokenization
|
||||
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.to(config.device)
|
||||
policy.eval()
|
||||
|
||||
# Get device before saving
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
||||
|
||||
# Explicitly move loaded_policy to the same device
|
||||
loaded_policy.to(device)
|
||||
loaded_policy.to(config.device)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_train_batch(
|
||||
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
||||
with seeded_context(12):
|
||||
# Process batch through preprocessor
|
||||
processed_batch = preprocessor(batch)
|
||||
# Move batch to the same device as the policy
|
||||
for key in processed_batch:
|
||||
if isinstance(processed_batch[key], torch.Tensor):
|
||||
processed_batch[key] = processed_batch[key].to(device)
|
||||
# Collect policy values before saving
|
||||
loss, _ = policy.forward(processed_batch)
|
||||
|
||||
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
|
||||
)
|
||||
|
||||
policy = MultiTaskDiTPolicy(config=config)
|
||||
policy.to(config.device)
|
||||
param_groups = policy.get_optim_params()
|
||||
|
||||
# Should have 2 parameter groups: non-vision and vision encoder
|
||||
|
||||
@@ -18,6 +18,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
if not is_package_available("reachy2_sdk"):
|
||||
pytest.skip("reachy2_sdk not available", allow_module_level=True)
|
||||
|
||||
from lerobot.teleoperators.reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
|
||||
Reference in New Issue
Block a user