mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
feat(dependencies): require Python 3.12+ as minimum version (#3023)
* feat(dependecies): upgrade to python3.12 * fix(test): processor regex message * fix(test): processor regex message * fix(dependecies): resolve all tags in python 3.12 * fix(dependecies): add more hints to faster resolve * chore(dependecies): remove cli tag huggingface-hub dep * refactor(policy): update eagle for python3.12 * chore(docs): update policy creation for python 3.12 * chore(test): skip failing tests in macos
This commit is contained in:
@@ -44,7 +44,7 @@ permissions:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
|
|
||||||
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ permissions:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
|
||||||
|
|
||||||
# Ensures that only the latest action is built, canceling older runs.
|
# Ensures that only the latest action is built, canceling older runs.
|
||||||
@@ -185,7 +185,7 @@ jobs:
|
|||||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||||
hf auth whoami
|
hf auth whoami
|
||||||
- name: Fix ptxas permissions
|
- name: Fix ptxas permissions
|
||||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
run: chmod +x /lerobot/.venv/lib/python3.12/site-packages/triton/backends/nvidia/bin/ptxas
|
||||||
- name: Run pytest on GPU
|
- name: Run pytest on GPU
|
||||||
run: pytest tests -vv --maxfail=10
|
run: pytest tests -vv --maxfail=10
|
||||||
- name: Run end-to-end tests
|
- name: Run end-to-end tests
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ on:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest
|
||||||
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Run pre-commit hooks
|
- name: Run pre-commit hooks
|
||||||
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ on:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# This job builds the Python package and publishes it to PyPI
|
# This job builds the Python package and publishes it to PyPI
|
||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v6
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.12'
|
||||||
|
|
||||||
- name: Extract Version
|
- name: Extract Version
|
||||||
id: extract_info
|
id: extract_info
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ permissions:
|
|||||||
# Sets up the environment variables
|
# Sets up the environment variables
|
||||||
env:
|
env:
|
||||||
UV_VERSION: "0.8.0"
|
UV_VERSION: "0.8.0"
|
||||||
PYTHON_VERSION: "3.10"
|
PYTHON_VERSION: "3.12"
|
||||||
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound
|
||||||
|
|
||||||
# Ensures that only the latest action is built, canceling older runs.
|
# Ensures that only the latest action is built, canceling older runs.
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.12
|
||||||
|
|
||||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ repos:
|
|||||||
rev: v3.21.0
|
rev: v3.21.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
args: [--py310-plus]
|
args: [--py312-plus]
|
||||||
|
|
||||||
##### Markdown Quality #####
|
##### Markdown Quality #####
|
||||||
- repo: https://github.com/rbubley/mirrors-prettier
|
- repo: https://github.com/rbubley/mirrors-prettier
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ ARG OS_VERSION=22.04
|
|||||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||||
|
|
||||||
# Define Python version argument
|
# Define Python version argument
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.12
|
||||||
|
|
||||||
# Configure environment variables
|
# Configure environment variables
|
||||||
ENV DEBIAN_FRONTEND=noninteractive \
|
ENV DEBIAN_FRONTEND=noninteractive \
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
# docker run -it --rm lerobot-user
|
# docker run -it --rm lerobot-user
|
||||||
|
|
||||||
# Configure the base image
|
# Configure the base image
|
||||||
ARG PYTHON_VERSION=3.10
|
ARG PYTHON_VERSION=3.12
|
||||||
FROM python:${PYTHON_VERSION}-slim
|
FROM python:${PYTHON_VERSION}-slim
|
||||||
|
|
||||||
# Configure environment variables
|
# Configure environment variables
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
# your policy-specific dependencies
|
# your policy-specific dependencies
|
||||||
]
|
]
|
||||||
requires-python = ">= 3.11"
|
requires-python = ">= 3.12"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
build-backend = # your-build-backend
|
build-backend = # your-build-backend
|
||||||
@@ -82,7 +82,7 @@ Create your policy implementation by inheriting from LeRobot's base `PreTrainedP
|
|||||||
# modeling_my_custom_policy.py
|
# modeling_my_custom_policy.py
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||||
@@ -91,7 +91,7 @@ class MyCustomPolicy(PreTrainedPolicy):
|
|||||||
config_class = MyCustomPolicyConfig
|
config_class = MyCustomPolicyConfig
|
||||||
name = "my_custom_policy"
|
name = "my_custom_policy"
|
||||||
|
|
||||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
|
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||||
super().__init__(config, dataset_stats)
|
super().__init__(config, dataset_stats)
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
@@ -102,7 +102,7 @@ Create processor functions:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
# processor_my_custom_policy.py
|
# processor_my_custom_policy.py
|
||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu
|
|||||||
### Hardware
|
### Hardware
|
||||||
|
|
||||||
- EarthRover Mini robot
|
- EarthRover Mini robot
|
||||||
- Computer with Python 3.10 or newer
|
- Computer with Python 3.12 or newer
|
||||||
- Internet connection
|
- Internet connection
|
||||||
|
|
||||||
### Setting Up the Frodobots SDK
|
### Setting Up the Frodobots SDK
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Installation
|
# Installation
|
||||||
|
|
||||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.12 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||||
|
|
||||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||||
|
|
||||||
@@ -11,10 +11,10 @@ bash Miniforge3-$(uname)-$(uname -m).sh
|
|||||||
|
|
||||||
## Step 2: Environment Setup
|
## Step 2: Environment Setup
|
||||||
|
|
||||||
Create a virtual environment with Python 3.10, using conda:
|
Create a virtual environment with Python 3.12, using conda:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda create -y -n lerobot python=3.10
|
conda create -y -n lerobot python=3.12
|
||||||
```
|
```
|
||||||
|
|
||||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ SSH into the robot and install LeRobot:
|
|||||||
```bash
|
```bash
|
||||||
ssh unitree@<YOUR_ROBOT_IP>
|
ssh unitree@<YOUR_ROBOT_IP>
|
||||||
|
|
||||||
conda create -y -n lerobot python=3.10
|
conda create -y -n lerobot python=3.12
|
||||||
conda activate lerobot
|
conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
cd lerobot
|
cd lerobot
|
||||||
@@ -153,7 +153,7 @@ With the robot server running, you can now control the robot remotely. Let's lau
|
|||||||
### Step 1: Install LeRobot on your machine
|
### Step 1: Install LeRobot on your machine
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
conda create -y -n lerobot python=3.10
|
conda create -y -n lerobot python=3.12
|
||||||
conda activate lerobot
|
conda activate lerobot
|
||||||
git clone https://github.com/huggingface/lerobot.git
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
cd lerobot
|
cd lerobot
|
||||||
|
|||||||
+31
-22
@@ -29,7 +29,7 @@ version = "0.4.5"
|
|||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
license = { text = "Apache-2.0" }
|
license = { text = "Apache-2.0" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.12"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
|
||||||
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
|
||||||
@@ -50,7 +50,8 @@ classifiers = [
|
|||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Programming Language :: Python :: 3.13",
|
||||||
"Topic :: Software Development :: Build Tools",
|
"Topic :: Software Development :: Build Tools",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
]
|
]
|
||||||
@@ -61,26 +62,28 @@ dependencies = [
|
|||||||
# Hugging Face dependencies
|
# Hugging Face dependencies
|
||||||
"datasets>=4.0.0,<5.0.0",
|
"datasets>=4.0.0,<5.0.0",
|
||||||
"diffusers>=0.27.2,<0.36.0",
|
"diffusers>=0.27.2,<0.36.0",
|
||||||
"huggingface-hub[cli]>=1.0.0,<2.0.0",
|
"huggingface-hub>=1.0.0,<2.0.0",
|
||||||
"accelerate>=1.10.0,<2.0.0",
|
"accelerate>=1.10.0,<2.0.0",
|
||||||
|
|
||||||
# Core dependencies
|
# Core dependencies
|
||||||
|
"numpy>=2.0.0,<2.3.0", # TODO: upper bound imposed by opencv-python-headless
|
||||||
"setuptools>=71.0.0,<81.0.0",
|
"setuptools>=71.0.0,<81.0.0",
|
||||||
"cmake>=3.29.0.1,<4.2.0",
|
"cmake>=3.29.0.1,<4.2.0",
|
||||||
|
"packaging>=24.2,<26.0",
|
||||||
|
|
||||||
|
"torch>=2.2.1,<2.11.0",
|
||||||
|
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||||
|
"torchvision>=0.21.0,<0.26.0",
|
||||||
|
|
||||||
"einops>=0.8.0,<0.9.0",
|
"einops>=0.8.0,<0.9.0",
|
||||||
"opencv-python-headless>=4.9.0,<4.13.0",
|
"opencv-python-headless>=4.9.0,<4.13.0",
|
||||||
"av>=15.0.0,<16.0.0",
|
"av>=15.0.0,<16.0.0",
|
||||||
"jsonlines>=4.0.0,<5.0.0",
|
"jsonlines>=4.0.0,<5.0.0",
|
||||||
"packaging>=24.2,<26.0",
|
"pynput>=1.7.8,<1.9.0",
|
||||||
"pynput>=1.7.7,<1.9.0",
|
|
||||||
"pyserial>=3.5,<4.0",
|
"pyserial>=3.5,<4.0",
|
||||||
|
|
||||||
"wandb>=0.24.0,<0.25.0",
|
"wandb>=0.24.0,<0.25.0",
|
||||||
|
"draccus==0.10.0", # TODO: Relax version constraint
|
||||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
|
||||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
|
||||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
|
||||||
|
|
||||||
"draccus==0.10.0", # TODO: Remove ==
|
|
||||||
"gymnasium>=1.1.1,<2.0.0",
|
"gymnasium>=1.1.1,<2.0.0",
|
||||||
"rerun-sdk>=0.24.0,<0.27.0",
|
"rerun-sdk>=0.24.0,<0.27.0",
|
||||||
|
|
||||||
@@ -95,13 +98,14 @@ dependencies = [
|
|||||||
|
|
||||||
# Common
|
# Common
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||||
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||||
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
|
||||||
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
|
||||||
|
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"]
|
||||||
|
|
||||||
# Motors
|
# Motors
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||||
@@ -119,7 +123,7 @@ unitree_g1 = [
|
|||||||
"onnxruntime>=1.16.0,<2.0.0",
|
"onnxruntime>=1.16.0,<2.0.0",
|
||||||
"pin>=3.0.0,<4.0.0",
|
"pin>=3.0.0,<4.0.0",
|
||||||
"meshcat>=0.3.0,<0.4.0",
|
"meshcat>=0.3.0,<0.4.0",
|
||||||
"matplotlib>=3.9.0,<4.0.0",
|
"lerobot[matplotlib-dep]",
|
||||||
"casadi>=3.6.0,<4.0.0",
|
"casadi>=3.6.0,<4.0.0",
|
||||||
]
|
]
|
||||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||||
@@ -128,7 +132,7 @@ intelrealsense = [
|
|||||||
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
|
||||||
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
|
||||||
]
|
]
|
||||||
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
|
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
|
||||||
|
|
||||||
# Policies
|
# Policies
|
||||||
wallx = [
|
wallx = [
|
||||||
@@ -151,12 +155,12 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
@@ -165,13 +169,18 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
|
|||||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||||
|
|
||||||
# Simulation
|
# Simulation
|
||||||
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
|
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||||
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||||
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
|
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||||
metaworld = ["metaworld==3.0.0"]
|
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||||
|
|
||||||
# All
|
# All
|
||||||
all = [
|
all = [
|
||||||
|
# Resolver hint: scipy is pulled in transitively via lerobot[scipy-dep] through
|
||||||
|
# multiple extras below (aloha, metaworld, pi, wallx, phone). Listing it explicitly
|
||||||
|
# helps pip's resolver converge by constraining scipy early, before it encounters
|
||||||
|
# the loose scipy requirements from transitive deps like dm-control and metaworld.
|
||||||
|
"scipy>=1.14.0,<2.0.0",
|
||||||
"lerobot[dynamixel]",
|
"lerobot[dynamixel]",
|
||||||
"lerobot[gamepad]",
|
"lerobot[gamepad]",
|
||||||
"lerobot[hopejr]",
|
"lerobot[hopejr]",
|
||||||
@@ -192,7 +201,7 @@ all = [
|
|||||||
"lerobot[aloha]",
|
"lerobot[aloha]",
|
||||||
"lerobot[pusht]",
|
"lerobot[pusht]",
|
||||||
"lerobot[phone]",
|
"lerobot[phone]",
|
||||||
"lerobot[libero]",
|
"lerobot[libero]; sys_platform == 'linux'",
|
||||||
"lerobot[metaworld]",
|
"lerobot[metaworld]",
|
||||||
"lerobot[sarm]",
|
"lerobot[sarm]",
|
||||||
"lerobot[peft]",
|
"lerobot[peft]",
|
||||||
@@ -224,7 +233,7 @@ lerobot = ["envs/*.json"]
|
|||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py310"
|
target-version = "py312"
|
||||||
line-length = 110
|
line-length = 110
|
||||||
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||||
|
|
||||||
@@ -316,7 +325,7 @@ default.extend-ignore-identifiers-re = [
|
|||||||
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
python_version = "3.10"
|
python_version = "3.12"
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
follow_imports = "skip"
|
follow_imports = "skip"
|
||||||
# warn_return_any = true
|
# warn_return_any = true
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from collections import deque
|
|||||||
from collections.abc import Iterable, Iterator
|
from collections.abc import Iterable, Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -78,8 +78,6 @@ DEFAULT_FEATURES = {
|
|||||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
}
|
}
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||||
metadata = pq.read_metadata(parquet_path)
|
metadata = pq.read_metadata(parquet_path)
|
||||||
@@ -1234,7 +1232,7 @@ class LookAheadError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Backtrackable(Generic[T]):
|
class Backtrackable[T]:
|
||||||
"""
|
"""
|
||||||
Wrap any iterator/iterable so you can step back up to `history` items
|
Wrap any iterator/iterable so you can step back up to `history` items
|
||||||
and look ahead up to `lookahead` items.
|
and look ahead up to `lookahead` items.
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Protocol, TypeAlias
|
from typing import Protocol
|
||||||
|
|
||||||
import serial
|
import serial
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
@@ -38,8 +38,8 @@ from tqdm import tqdm
|
|||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||||
|
|
||||||
NameOrID: TypeAlias = str | int
|
type NameOrID = str | int
|
||||||
Value: TypeAlias = int | float
|
type Value = int | float
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1277,4 +1277,4 @@ class SerialMotorsBus(MotorsBusBase):
|
|||||||
|
|
||||||
|
|
||||||
# Backward compatibility alias
|
# Backward compatibility alias
|
||||||
MotorsBus: TypeAlias = SerialMotorsBus
|
MotorsBus = SerialMotorsBus
|
||||||
|
|||||||
@@ -18,10 +18,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
|||||||
@@ -4,10 +4,9 @@
|
|||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers.image_processing_utils import (
|
from transformers.image_processing_utils import (
|
||||||
BatchFeature,
|
BatchFeature,
|
||||||
get_patch_output_size,
|
get_patch_output_size,
|
||||||
@@ -165,11 +164,11 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _resize_for_patching(
|
def _resize_for_patching(
|
||||||
self,
|
self,
|
||||||
image: "torch.Tensor",
|
image: torch.Tensor,
|
||||||
target_resolution: tuple,
|
target_resolution: tuple,
|
||||||
interpolation: "F.InterpolationMode",
|
interpolation: F.InterpolationMode,
|
||||||
input_data_format: ChannelDimension,
|
input_data_format: ChannelDimension,
|
||||||
) -> "torch.Tensor":
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||||
|
|
||||||
@@ -219,8 +218,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
return best_ratio
|
return best_ratio
|
||||||
|
|
||||||
def _pad_for_patching(
|
def _pad_for_patching(
|
||||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||||
) -> "torch.Tensor":
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Pad an image to a target resolution while maintaining aspect ratio.
|
Pad an image to a target resolution while maintaining aspect ratio.
|
||||||
"""
|
"""
|
||||||
@@ -236,15 +235,15 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _get_image_patches(
|
def _get_image_patches(
|
||||||
self,
|
self,
|
||||||
image: "torch.Tensor",
|
image: torch.Tensor,
|
||||||
min_num: int,
|
min_num: int,
|
||||||
max_num: int,
|
max_num: int,
|
||||||
size: tuple,
|
size: tuple,
|
||||||
tile_size: int,
|
tile_size: int,
|
||||||
use_thumbnail: bool,
|
use_thumbnail: bool,
|
||||||
interpolation: "F.InterpolationMode",
|
interpolation: F.InterpolationMode,
|
||||||
pad_during_tiling: bool,
|
pad_during_tiling: bool,
|
||||||
) -> list["torch.Tensor"]:
|
) -> list[torch.Tensor]:
|
||||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||||
orig_height, orig_width = image_size
|
orig_height, orig_width = image_size
|
||||||
aspect_ratio = orig_width / orig_height
|
aspect_ratio = orig_width / orig_height
|
||||||
@@ -305,8 +304,8 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _pad_for_batching(
|
def _pad_for_batching(
|
||||||
self,
|
self,
|
||||||
pixel_values: list["torch.Tensor"],
|
pixel_values: list[torch.Tensor],
|
||||||
) -> list["torch.Tensor"]:
|
) -> list[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||||
|
|
||||||
@@ -327,14 +326,14 @@ class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
def _preprocess(
|
def _preprocess(
|
||||||
self,
|
self,
|
||||||
images: list["torch.Tensor"],
|
images: list[torch.Tensor],
|
||||||
do_resize: bool,
|
do_resize: bool,
|
||||||
size: SizeDict,
|
size: SizeDict,
|
||||||
max_dynamic_tiles: int,
|
max_dynamic_tiles: int,
|
||||||
min_dynamic_tiles: int,
|
min_dynamic_tiles: int,
|
||||||
use_thumbnail: bool,
|
use_thumbnail: bool,
|
||||||
pad_during_tiling: bool,
|
pad_during_tiling: bool,
|
||||||
interpolation: Optional["F.InterpolationMode"],
|
interpolation: F.InterpolationMode | None,
|
||||||
do_center_crop: bool,
|
do_center_crop: bool,
|
||||||
crop_size: SizeDict,
|
crop_size: SizeDict,
|
||||||
do_rescale: bool,
|
do_rescale: bool,
|
||||||
|
|||||||
@@ -20,12 +20,11 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
|||||||
@@ -20,12 +20,11 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
|||||||
@@ -19,13 +19,12 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
from lerobot.utils.import_utils import _scipy_available, _transformers_available
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import os
|
|||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import TypedDict, TypeVar
|
from typing import TypedDict, TypeVar, Unpack
|
||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
import safetensors
|
import safetensors
|
||||||
@@ -28,7 +28,6 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
|||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
|||||||
@@ -54,12 +54,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import TypedDict
|
from typing import TypedDict, Unpack
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from typing_extensions import Unpack
|
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypeAlias, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -36,10 +36,10 @@ class TransitionKey(str, Enum):
|
|||||||
COMPLEMENTARY_DATA = "complementary_data"
|
COMPLEMENTARY_DATA = "complementary_data"
|
||||||
|
|
||||||
|
|
||||||
PolicyAction: TypeAlias = torch.Tensor
|
PolicyAction = torch.Tensor
|
||||||
RobotAction: TypeAlias = dict[str, Any]
|
RobotAction = dict[str, Any]
|
||||||
EnvAction: TypeAlias = np.ndarray
|
EnvAction = np.ndarray
|
||||||
RobotObservation: TypeAlias = dict[str, Any]
|
RobotObservation = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
EnvTransition = TypedDict(
|
EnvTransition = TypedDict(
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ from collections.abc import Callable, Iterable, Sequence
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
from typing import Any, TypedDict, TypeVar, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -251,7 +251,7 @@ class ProcessorMigrationError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||||
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
"""A sequential pipeline for processing data, integrated with the Hugging Face Hub.
|
||||||
|
|
||||||
This class chains together multiple `ProcessorStep` instances to form a complete
|
This class chains together multiple `ProcessorStep` instances to form a complete
|
||||||
@@ -1432,8 +1432,8 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
|||||||
|
|
||||||
|
|
||||||
# Type aliases for semantic clarity.
|
# Type aliases for semantic clarity.
|
||||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
RobotProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput]
|
PolicyProcessorPipeline = DataProcessorPipeline[TInput, TOutput]
|
||||||
|
|
||||||
|
|
||||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
from lerobot.cameras import CameraConfig
|
||||||
|
|
||||||
@@ -50,5 +49,5 @@ class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
SO100FollowerConfig = SOFollowerRobotConfig
|
||||||
SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig
|
SO101FollowerConfig = SOFollowerRobotConfig
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
@@ -230,5 +229,5 @@ class SOFollower(Robot):
|
|||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
|
||||||
SO100Follower: TypeAlias = SOFollower
|
SO100Follower = SOFollower
|
||||||
SO101Follower: TypeAlias = SOFollower
|
SO101Follower = SOFollower
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from ..config import TeleoperatorConfig
|
from ..config import TeleoperatorConfig
|
||||||
|
|
||||||
@@ -38,5 +37,5 @@ class SOLeaderTeleopConfig(TeleoperatorConfig, SOLeaderConfig):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
SO100LeaderConfig: TypeAlias = SOLeaderTeleopConfig
|
SO100LeaderConfig = SOLeaderTeleopConfig
|
||||||
SO101LeaderConfig: TypeAlias = SOLeaderTeleopConfig
|
SO101LeaderConfig = SOLeaderTeleopConfig
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import TypeAlias
|
|
||||||
|
|
||||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
from lerobot.motors.feetech import (
|
from lerobot.motors.feetech import (
|
||||||
@@ -156,5 +155,5 @@ class SOLeader(Teleoperator):
|
|||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|
||||||
|
|
||||||
SO100Leader: TypeAlias = SOLeader
|
SO100Leader = SOLeader
|
||||||
SO101Leader: TypeAlias = SOLeader
|
SO101Leader = SOLeader
|
||||||
|
|||||||
@@ -16,12 +16,10 @@
|
|||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
import imageio
|
import imageio
|
||||||
|
|
||||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||||
T = TypeVar("T", bound=JsonLike)
|
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
def write_video(video_path, stacked_frames, fps):
|
||||||
@@ -33,7 +31,7 @@ def write_video(video_path, stacked_frames, fps):
|
|||||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T:
|
||||||
"""
|
"""
|
||||||
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
||||||
corresponding values (strictly matching structure and types).
|
corresponding values (strictly matching structure and types).
|
||||||
|
|||||||
@@ -143,12 +143,18 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||||
and for now we add tests as we see fit.
|
and for now we add tests as we see fit.
|
||||||
"""
|
"""
|
||||||
|
if policy_name == "vqbet" and DEVICE == "mps":
|
||||||
|
pytest.skip("VQBet does not support MPS backend")
|
||||||
|
if policy_name == "act" and "aloha" in ds_repo_id and DEVICE == "mps":
|
||||||
|
pytest.skip("ACT with aloha has batch mutation issues on MPS")
|
||||||
|
|
||||||
train_cfg = TrainPipelineConfig(
|
train_cfg = TrainPipelineConfig(
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
||||||
env=make_env_config(env_name, **env_kwargs),
|
env=make_env_config(env_name, **env_kwargs),
|
||||||
)
|
)
|
||||||
|
train_cfg.policy.device = DEVICE
|
||||||
train_cfg.validate()
|
train_cfg.validate()
|
||||||
|
|
||||||
# Check that we can make the policy object.
|
# Check that we can make the policy object.
|
||||||
@@ -227,6 +233,7 @@ def test_act_backbone_lr():
|
|||||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
||||||
)
|
)
|
||||||
|
cfg.policy.device = DEVICE
|
||||||
cfg.validate() # Needed for auto-setting some parameters
|
cfg.validate() # Needed for auto-setting some parameters
|
||||||
|
|
||||||
assert cfg.policy.optimizer_lr == 0.01
|
assert cfg.policy.optimizer_lr == 0.01
|
||||||
|
|||||||
@@ -1870,9 +1870,7 @@ class NonCallableStep(ProcessorStep):
|
|||||||
|
|
||||||
def test_construction_rejects_step_without_call():
|
def test_construction_rejects_step_without_call():
|
||||||
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
"""Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep."""
|
||||||
with pytest.raises(
|
with pytest.raises(TypeError, match=r"Can't instantiate abstract class NonCallableStep"):
|
||||||
TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_"
|
|
||||||
):
|
|
||||||
DataProcessorPipeline([NonCallableStep()])
|
DataProcessorPipeline([NonCallableStep()])
|
||||||
|
|
||||||
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"):
|
||||||
|
|||||||
+2
-1
@@ -22,8 +22,9 @@ import torch
|
|||||||
|
|
||||||
from lerobot import available_cameras, available_motors, available_robots
|
from lerobot import available_cameras, available_motors, available_robots
|
||||||
from lerobot.utils.import_utils import is_package_available
|
from lerobot.utils.import_utils import is_package_available
|
||||||
|
from lerobot.utils.utils import auto_select_torch_device
|
||||||
|
|
||||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
|
||||||
|
|
||||||
TEST_ROBOT_TYPES = []
|
TEST_ROBOT_TYPES = []
|
||||||
for robot_type in available_robots:
|
for robot_type in available_robots:
|
||||||
|
|||||||
Reference in New Issue
Block a user